In [None]:
from fastai import *
from fastai.vision import *
from fastai.callbacks import *
from fastai.distributed import *
from superres import superres_metrics, ssim, psnr
from superres import RRDB_Net, MultiImageToMultiChannel, TwoXModel, TwoYLoss, MultiToMultiImageList, TransformableLists
import czifile
from superres.helpers import get_czi_shape_info, build_index
import imageio
import adabound

# Scratch-area roots and the multi-frame dataset used throughout the notebook.
datasetname = 'multiframe_002'
data_path = Path('/scratch/bpho')
datasets = data_path / 'datasets'
datasources = data_path / 'datasources'
dataset = datasets / datasetname

# Per-resolution sub-folders of the dataset (`lr_up` presumably holds
# upsampled low-res copies -- not used in the visible cells).
hr_path, lr_path, lr_up_path = (dataset / sub for sub in ('hr', 'lr', 'lr_up'))

In [None]:
from numbers import Integral

class MultiImage(ItemBase):
    "Several fastai `Image`s handled as a single item; transforms are applied to each of them."

    def __init__(self, img_list):
        self.img_list = img_list

    def __repr__(self):
        names = [str(frame) for frame in self.img_list]
        return f'MultiImage: {names}'

    @property
    def size(self):
        "List holding the `size` of every wrapped image, in order."
        return list(map(lambda frame: frame.size, self.img_list))

    @property
    def data(self):
        "Stack the images and merge the image axis into the channel axis: (n*c, h, w)."
        stacked = torch.stack([frame.data for frame in self.img_list])
        n_frames, n_chans, height, width = stacked.shape
        return tensor(stacked.view(n_frames * n_chans, height, width))

    def apply_tfms(self, tfms, **kwargs):
        """Replace `self.img_list` with the transformed images and return `self`.

        Only the first image is called with `do_resolve=True`; presumably so the
        remaining images reuse the random parameters resolved for it.
        """
        self.img_list = [frame.apply_tfms(tfms, do_resolve=(pos == 0), **kwargs)
                         for pos, frame in enumerate(self.img_list)]
        return self

    def _repr_png_(self):
        return self._repr_image_format('png')

    def _repr_jpeg_(self):
        return self._repr_image_format('jpeg')

    def _repr_image_format(self, format_str):
        "Encode all images, concatenated along axis 1, as `format_str` bytes."
        with BytesIO() as buf:
            strip = np.concatenate([image2np(frame.px) for frame in self.img_list], axis=1)
            plt.imsave(buf, strip, format=format_str)
            return buf.getvalue()

    def show(self, **kwargs):
        "Display only the first wrapped image."
        first = self.img_list[0]
        first.show(**kwargs)


        
        
        
def multi_normalize(x:TensorImage, mean:FloatTensor, std:FloatTensor)->TensorImage:
    "Standardize `x`: subtract `mean`, then divide by `std` (broadcasting)."
    centered = x - mean
    return centered / std


def multi_denormalize(x:TensorImage, mean:FloatTensor, std:FloatTensor, do_x:bool=True)->TensorImage:
    "Move `x` to CPU and, when `do_x`, undo `multi_normalize` (as float): x*std + mean."
    x_cpu = x.cpu()
    if not do_x:
        return x_cpu
    return x_cpu.float() * std + mean

def _multi_normalize_batch(b:Tuple[Tensor,Tensor], mean:FloatTensor, std:FloatTensor, do_x:bool=True, do_y:bool=False)->Tuple[Tensor,Tensor]:
    "Normalize the `(x, y)` batch `b`: `x` when `do_x`, and `y` only when `do_y` and `y` is 4-D."
    x, y = b
    dev = x.device
    mean, std = mean.to(dev), std.to(dev)
    new_x = multi_normalize(x, mean, std) if do_x else x
    # Non-image (non 4-D) targets are passed through untouched.
    new_y = multi_normalize(y, mean, std) if (do_y and len(y.shape) == 4) else y
    return new_x, new_y

def multi_normalize_funcs(mean:FloatTensor, std:FloatTensor, do_x:bool=True, do_y:bool=False)->Tuple[Callable,Callable]:
    "Return `(norm, denorm)` callables bound to `mean`/`std`; `do_x`/`do_y` choose what `norm` touches."
    mean_t, std_t = tensor(mean), tensor(std)
    norm_fn = partial(_multi_normalize_batch, mean=mean_t, std=std_t, do_x=do_x, do_y=do_y)
    denorm_fn = partial(multi_denormalize, mean=mean_t, std=std_t, do_x=do_x)
    return norm_fn, denorm_fn
        
def multi_image_channel_view(x):
    "Swap axes 0 and 1 of `x` and flatten everything into a single row of shape (1, N)."
    rows = 1
    swapped = x.transpose(0, 1).contiguous()
    return swapped.view(rows, -1)

    

class MultiImageDataBunch(ImageDataBunch):
    "`ImageDataBunch` whose normalization statistics are pooled over all planes of a multi-image batch."

    def batch_stats(self, funcs:Optional[Collection[Callable]]=None, ds_type:DatasetType=DatasetType.Train)->Collection[Tensor]:
        """Grab one un-normalized batch from `ds_type` and reduce it with each of `funcs`.

        `funcs` defaults to `[torch.mean, torch.std]`. The batch is reshaped with
        `multi_image_channel_view` and every function is applied along dim 1, so the
        result is a list with one tensor per function.
        """
        funcs = ifnone(funcs, [torch.mean, torch.std])
        x = self.one_batch(ds_type=ds_type, denorm=False)[0].cpu()
        return [func(multi_image_channel_view(x), 1) for func in funcs]

    def normalize(self, stats:Optional[Collection[Tensor]]=None, do_x:bool=True, do_y:bool=False)->'MultiImageDataBunch':
        """Add a normalize transform using `stats` (defaults to `self.batch_stats()`).

        Stores `self.stats`, `self.norm` and `self.denorm`, and returns `self` so the
        call can be chained after `.databunch(...)`.

        Raises:
            Exception: if normalization was already added to this bunch.
        """
        if getattr(self, 'norm', False): raise Exception('Can not call normalize twice')
        self.stats = self.batch_stats() if stats is None else stats
        self.norm, self.denorm = multi_normalize_funcs(*self.stats, do_x=do_x, do_y=do_y)
        self.add_tfm(self.norm)
        return self



class MultiImageList(ImageList):
    "`ImageList` whose items are `MultiImage`s loaded from `.npy` files."
    _bunch, _square_show, _square_show_res = MultiImageDataBunch, True, True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.channels = 1

    def open(self, fn):
        """Load `fn` with `np.load` and wrap each plane as an `Image` with a leading channel axis.

        Planes are taken along the first axis, or along the first two axes (outer
        axis first) when the array is 4-D. `self.channels` is updated from the
        first image.
        """
        arr = np.load(fn)
        if len(arr.shape) == 4:
            planes = [arr[j, i] for j in range(arr.shape[0]) for i in range(arr.shape[1])]
        else:
            planes = [arr[i] for i in range(arr.shape[0])]
        frames = [Image(tensor(plane[None])) for plane in planes]
        self.channels = frames[0].data.shape[0]
        return MultiImage(frames)

    def reconstruct(self, t:Tensor):
        "View the (n, h, w) tensor `t` as (channels, n//channels*h, w), clamp to [0, 1], wrap as `Image`."
        n_planes, height, width = t.shape
        per_chan = n_planes // self.channels
        tiled = t.view(self.channels, per_chan * height, width)
        return Image(tiled.float().clamp(min=0, max=1))

class NpyRawImageList(ImageList):
    "`ImageList` of `.npy` arrays, each opened as an `Image` with a leading channel axis added."

    def open(self, fn):
        arr = np.load(fn)
        return Image(tensor(arr[None]))

    def analyze_pred(self, pred):
        "Keep only the first channel of `pred`."
        first = slice(0, 1)
        return pred[first]

    def reconstruct(self, t):
        clipped = t.float().clamp(min=0, max=1)
        return Image(clipped)

class MultiImageImageList(MultiImageList):
    "`MultiImageList` labelled with `NpyRawImageList` targets."
    _label_cls, _square_show, _square_show_res = NpyRawImageList, False, False

    def show_xys(self, xs, ys, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Plot one row per sample: input in column 0, target in column 1."
        axs = subplots(len(xs), 2, imgsize=imgsize, figsize=figsize)
        for row, pair in enumerate(zip(xs, ys)):
            inp, tgt = pair
            inp.show(ax=axs[row, 0], **kwargs)
            tgt.show(ax=axs[row, 1], **kwargs)
        plt.tight_layout()

    def show_xyzs(self, xs, ys, zs, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Plot one row per sample: input, prediction, target (left to right), under a bold title."
        title = 'Input / Prediction / Target'
        axs = subplots(len(xs), 3, imgsize=imgsize, figsize=figsize, title=title, weight='bold', size=14)
        for row, (inp, tgt, pred) in enumerate(zip(xs, ys, zs)):
            inp.show(ax=axs[row, 0], **kwargs)
            tgt.show(ax=axs[row, 2], **kwargs)
            pred.show(ax=axs[row, 1], **kwargs)

In [None]:
def map_to_hr(x):
    "Return the path under `hr_path` that mirrors `x`'s position under `lr_path`."
    rel = x.relative_to(lr_path)
    return hr_path.joinpath(rel)

def get_src(size=128, scale=1):
    """Build the labelled, folder-split item list for the multi-frame dataset.

    Inputs are the `.npy` files under `lr_path`; each is labelled by `map_to_hr`
    (same relative path under `hr_path`) and split with `split_by_folder()`.

    `size` and `scale` are accepted for call-site compatibility but have no effect:
    the former `scale` switch selected `lr_path` in both branches.
    """
    # NOTE(review): if a scale == 1 run was meant to read other inputs (e.g.
    # `lr_up_path`), both `from_folder` and the label mapping would need to change.
    # The module-level `map_to_hr` is used instead of a local closure, which also
    # keeps the label function picklable.
    return (MultiImageImageList
            .from_folder(lr_path, extensions=['.npy'])
            .split_by_folder()
            .label_from_func(map_to_hr))

def get_data(bs, size, scale=1, max_zoom=1.):
    """Return a normalized data bunch with batch size `bs`.

    Inputs are transformed to `size` and targets to `size * scale`, using the same
    `get_transforms(flip_vert=True, max_zoom=max_zoom)`; targets are normalized too.
    """
    src = get_src(size, scale=scale)
    tfms = get_transforms(flip_vert=True, max_zoom=max_zoom)
    x_tfmed = src.transform(tfms, size=size)
    xy_tfmed = x_tfmed.transform_y(tfms, size=size * scale)
    bunch = xy_tfmed.databunch(bs=bs)
    return bunch.normalize(do_y=True)

In [None]:
# Sanity check: number of visible CUDA devices (the models below are wrapped in nn.DataParallel).
torch.cuda.device_count()

In [None]:
# 64px input tiles with 4x (256px) targets.
bs = 8
size = 64
data = get_data(bs, size, scale=4, max_zoom=1.2)

In [None]:
%debug

In [None]:
class TwoXModel(nn.Module):
    """Run one shared `model` on each half of the channel axis and concatenate the outputs.

    Redefines the `TwoXModel` imported from `superres` at the top of the notebook.
    With an odd channel count, the second half gets the extra channel.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        split_at = x.shape[1] // 2
        halves = (x[:, :split_at], x[:, split_at:])
        outputs = [self.model(half) for half in halves]
        return torch.cat(outputs, 1)
    
    
class TwoYLoss(nn.Module):
    """Loss for a two-output model: per-output base loss plus an output-agreement term.

    `input` channel 0 and channel 1 are two single-channel predictions; each is
    compared to `target` with `base_loss`, and their mutual MSE ("stable") is added.
    Weights: `(1 - stable_wt) / 2` for each base loss, `stable_wt` for the stable term.
    After each call `self.metrics` maps the names in `metric_names` to values for a
    `LossMetrics`-style callback.
    """
    def __init__(self, base_loss=F.mse_loss, stable_wt=0.15):
        super().__init__()
        self.base_loss = base_loss
        self.stable_wt = stable_wt
        self.base_loss_wt = (1 - stable_wt) / 2
        self.metric_names = ['pixel', 'stable', 'ssim', 'psnr']

    def forward(self, input, target):
        y1 = input[:, 0:1, :, :]
        y2 = input[:, 1:2, :, :]
        base_1 = self.base_loss(y1, target)
        base_2 = self.base_loss(y2, target)
        stable_err = F.mse_loss(y1, y2)
        loss = (base_1 * self.base_loss_wt +
                base_2 * self.base_loss_wt +
                stable_err * self.stable_wt)
        # The metrics are for monitoring only: computing them without autograd keeps
        # `self.metrics` from holding computation graphs alive until the next call.
        with torch.no_grad():
            self.metrics = {
                'pixel': (base_1 + base_2) / 2,
                'stable': stable_err.detach(),
                'ssim': (ssim.ssim(y1, target) + ssim.ssim(y2, target)) / 2,
                'psnr': (psnr(y1, target) + psnr(y2, target)) / 2,
            }
        return loss
    


In [None]:
# (Re)build the 64px data bunch and a single-output RRDB_Net trained with plain MSE.
bs = 8
size = 64
data = get_data(bs, size, scale=4, max_zoom=1.2)

# Arguments for RRDB_Net(in_c, out_c, nf, nb, gc=gcval); the commented numbers look
# like larger alternative settings. in_c presumably equals the number of stacked
# input frames per sample -- TODO confirm against the .npy data.
in_c = 5
out_c = 1
nf = 16
gcval = 16 # 32
nb = 5 # 23

# The two-output TwoYLoss / TwoXModel variant is kept commented out.
#loss = TwoYLoss(F.mse_loss)
loss = F.mse_loss

#model = TwoXModel(RRDB_Net(in_c, out_c, nf, nb, gc=gcval))
model = RRDB_Net(in_c, out_c, nf, nb, gc=gcval)
model = nn.DataParallel(model)
#opt_func = partial(adabound.AdaBound,  lr=1e-6)

In [None]:
# Inspect one batch: shapes and value statistics of inputs `x` and targets `y`.
x,y = data.one_batch()

In [None]:
x.shape, y.shape

In [None]:
x.min(), x.max(), x.mean(), x.std()

In [None]:
y.min(), y.max(), y.mean(), y.std()

In [None]:
# Plain-MSE learner. The commented LossMetrics variant needs a loss exposing
# `metric_names` (TwoYLoss defines one; F.mse_loss does not).
#learn = Learner(data, model, callback_fns=LossMetrics, loss_func=loss) #, opt_func=opt_func)
learn = Learner(data, model, loss_func=loss) #, opt_func=opt_func)

# Mixed precision; loss_scale=64 is passed explicitly here, unlike the later cells.
learn = learn.to_fp16(loss_scale=64)
gc.collect()

In [None]:
learn.data.show_batch(3)

In [None]:
#learn.lr_find()
#learn.recorder.plot()

In [None]:
# One one-cycle epoch at max_lr=1e-3 on 64px tiles.
learn.fit_one_cycle(cyc_len=1, max_lr=1e-3)
#learn.fit(4, m)

In [None]:
# Checkpoint that the 256px stage loads.
learn.save('multiimg.1')

In [None]:
learn.show_results(rows=3)

In [None]:
x,y = data.one_batch()
x.shape, y.shape

In [None]:
# 256px stage: same architecture as the 64px stage, warm-started from 'multiimg.1'.
bs = 4
size = 256
data = get_data(bs, size, scale=4, max_zoom=1.2)
in_c = 5
out_c = 1
nf = gcval = 16 # 32
nb = 5 # 23

# Single-output model: the checkpoint loaded below was saved from a plain
# DataParallel(RRDB_Net), so the TwoXModel wrapper is not built here.
#model = TwoXModel(RRDB_Net(in_c, out_c, nf, nb, gc=gcval))
model = RRDB_Net(in_c, out_c, nf, nb, gc=gcval)
model = nn.DataParallel(model)
learn = Learner(data, model, loss_func=loss) #, opt_func=opt_func)
learn = learn.to_fp16()
learn = learn.load('multiimg.1')

In [None]:
# Two one-cycle epochs at max_lr=1e-3 on 256px tiles.
learn.fit_one_cycle(2, max_lr=1e-3)

In [None]:
# Checkpoint that the inference cell loads.
learn.save('multiimg.2')

In [None]:
learn.show_results(rows=3, imgsize=5)

In [None]:
# Movies to super-resolve; alternative sources are kept commented out.
movie_files = []
#movie_files = list(Path('/scratch/bpho/datasets/movies_001/test').glob('*.czi'))
#movie_files += list(Path('/scratch/bpho/datasources/low_res_test/').glob('low res confocal*.czi'))
movie_files += list(Path('/scratch/bpho/datasources/neuron_movies2/').glob('low*.*'))

In [None]:
#movie_files = movie_files[0:2]

In [None]:
def image_from_tiles(learn, in_img, tile_sz=128, scale=4, wsize=3):
    """Super-resolve `in_img` tile by tile and stitch the predictions together.

    `in_img` is indexed as (frames, rows, cols) -- NOTE(review): confirm for all callers.
    Each `tile_sz` x `tile_sz` tile is copied into a zeroed buffer (so edge tiles are
    zero-padded); its first `wsize` frames are passed to `learn.predict` as one
    `MultiImage`. The prediction is assumed to be `scale` times the tile size; its
    valid region is written into the returned (1, rows*scale, cols*scale) tensor.
    """
    n_frames = in_img.shape[0]
    rows, cols = in_img.shape[1:3]

    in_tile = torch.zeros((n_frames, tile_sz, tile_sz))
    out_img = torch.zeros((1, rows * scale, cols * scale))

    for r_start in range(0, rows, tile_sz):
        for c_start in range(0, cols, tile_sz):
            r_end = min(r_start + tile_sz, rows)
            c_end = min(c_start + tile_sz, cols)
            tile_rows, tile_cols = r_end - r_start, c_end - c_start

            # Clear the reused buffer so a partial edge tile is padded with zeros
            # instead of leftover pixels from the previously processed tile.
            in_tile.zero_()
            in_tile[:, :tile_rows, :tile_cols] = tensor(in_img[:, r_start:r_end, c_start:c_end])

            frames = MultiImage([Image(in_tile[i][None]) for i in range(wsize)])
            out_tile, _, _ = learn.predict(frames)

            # Copy only the part of the prediction that corresponds to real input pixels.
            out_img[:, r_start*scale:r_end*scale, c_start*scale:c_end*scale] = \
                out_tile.data[:, :tile_rows*scale, :tile_cols*scale]
    return out_img

In [1]:
from skimage.io import imsave

def tif_predict_movie(learn, tif_in, orig_out='orig.tif', pred_out='pred.tif', size=128, wsize=3):
    """Super-resolve a multi-page TIFF movie with a sliding window of `wsize` frames.

    Every page is read with PIL, cast to float32 and divided by 255; each window is
    then divided by the movie-wide maximum. For each window start `t` in
    `0 .. n_frames - wsize`, `image_from_tiles` (tile size `size`) predicts one frame.
    Predictions and the matching input reference frames are scaled to uint8 and
    written with `imageio.mimwrite(..., imagej=True)` to `pred_out` and `orig_out`.

    Raises:
        ValueError: if the movie has fewer than `wsize` frames.
    """
    # Read all pages up front; the context manager closes the file afterwards.
    with PIL.Image.open(tif_in) as im:
        im.load()
        times = im.n_frames
        if times < wsize:
            raise ValueError(f'{tif_in}: {times} frame(s), need at least wsize={wsize}')
        imgs = []
        for i in range(times):
            im.seek(i)
            imgs.append(np.array(im).astype(np.float32) / 255.)
    img_data = np.stack(imgs)

    img_max = img_data.max()
    print('max: ', img_max)
    # Guard against an all-zero movie, where 0/0 would produce NaNs.
    norm = img_max if img_max > 0 else 1.

    preds = []
    origs = []
    for t in progress_bar(list(range(0, times - wsize + 1))):
        img = img_data[t:(t + wsize)].copy()
        img /= norm

        out_img = image_from_tiles(learn, img, tile_sz=size, wsize=wsize)
        preds.append((out_img * 255).cpu().numpy().astype(np.uint8))
        # NOTE(review): the saved reference is window frame 1, not the centre frame
        # (wsize // 2); confirm which frame the model is trained to reconstruct.
        origs.append((img[1][None] * 255).astype(np.uint8))

    imageio.mimwrite(pred_out, np.concatenate(preds), imagej=True)
    imageio.mimwrite(orig_out, np.concatenate(origs), imagej=True)


def czi_predict_movie(learn, czi_in, orig_out='orig.tif', pred_out='pred.tif', size=128, wsize=3):
    """Super-resolve a CZI time series with a sliding window of `wsize` time points.

    The whole file is loaded with `CziFile.asarray()`, cast to float32 and divided by
    255; each window is then divided by the global maximum. Windows run along T at
    channel 0 and Z-plane 0 (selected via `build_index`). For each window start `t`
    in `0 .. times - wsize`, `image_from_tiles` (tile size `size`) predicts one frame.
    Predictions and the matching input reference frames are scaled to uint8 and
    written with `imageio.mimwrite` to `pred_out` and `orig_out`.

    Raises:
        ValueError: if the file has fewer than `wsize` time points.
    """
    with czifile.CziFile(czi_in) as czi_f:
        proc_axes, proc_shape = get_czi_shape_info(czi_f)
        depths = proc_shape['Z']
        times = proc_shape['T']
        x, y = proc_shape['X'], proc_shape['Y']
        print(f'czi: x:{x} y:{y} t:{times} z:{depths}')
        if times < wsize:
            raise ValueError(f'{czi_in}: {times} time point(s), need at least wsize={wsize}')
        data = czi_f.asarray().astype(np.float32) / 255.
    # `data` is an in-memory copy, so the file is closed before the slow inference loop.

    img_max = data.max()
    print(img_max)
    # Guard against an all-zero movie, where 0/0 would produce NaNs.
    norm = img_max if img_max > 0 else 1.

    preds = []
    origs = []
    for t in progress_bar(list(range(0, times - wsize + 1))):
        idx = build_index(proc_axes, {'T': slice(t, t + wsize), 'C': 0, 'Z': 0,
                                      'X': slice(0, x), 'Y': slice(0, y)})
        img = data[idx].copy()
        img /= norm

        out_img = image_from_tiles(learn, img, tile_sz=size, wsize=wsize)
        preds.append((out_img * 255).cpu().numpy().astype(np.uint8))
        # NOTE(review): the saved reference is window frame 1, not the centre frame
        # (wsize // 2); confirm which frame the model is trained to reconstruct.
        origs.append((img[1][None] * 255).astype(np.uint8))

    imageio.mimwrite(pred_out, np.concatenate(preds))
    imageio.mimwrite(orig_out, np.concatenate(origs))


In [None]:
# Inference setup: `size` is also the tile size passed to the movie prediction loop.
bs = 2
size = 440
data = get_data(bs, size, scale=4, max_zoom=1.2)
in_c = 5
out_c = 1
nf = gcval = 16 # 32
nb = 5 # 23

#model = TwoXModel(RRDB_Net(in_c, out_c, nf, nb, gc=gcval))
model = RRDB_Net(in_c, out_c, nf, nb, gc=gcval)
loss = F.mse_loss

model = nn.DataParallel(model)
# No LossMetrics callback: it reads `loss_func.metric_names`, which plain F.mse_loss
# lacks (only TwoYLoss defines it), so any fit on this learner would fail.
learn = Learner(data, model, loss_func=loss) #, opt_func=opt_func)
learn = learn.to_fp16()
learn = learn.load('multiimg.2')

gc.collect()

In [None]:
# Sanity-check batch shapes at the inference tile size.
x,y = data.one_batch()
print(x.shape, y.shape)

In [None]:
# Predict every movie whose output is not already in the working directory.
# Only .czi and .tif suffixes are handled; other files are ignored.
# `tif_fn` records the last TIFF processed (read by the exploration cells below).
for fn in progress_bar(movie_files):
    pred_name = f'{fn.stem}_pred.tif'
    orig_name = f'{fn.stem}_orig.tif'
    if Path(pred_name).exists():
        print(f'skip: {fn.stem}')
        continue
    if fn.suffix == '.czi':
        print(f'czi {fn.stem}')
        czi_predict_movie(learn, fn, size=size, orig_out=orig_name, pred_out=pred_name, wsize=5)
    elif fn.suffix == '.tif':
        tif_predict_movie(learn, fn, size=size, orig_out=orig_name, pred_out=pred_name, wsize=5)
        tif_fn = fn
        print(f'tif {fn.stem}')

In [None]:
#%debug

In [None]:
# Exploration of the last TIFF handled by the prediction loop. `tif_fn` is only bound
# if that loop processed a .tif file; `im` is left open for the following cells.
im = PIL.Image.open(tif_fn)
im.load()
im.n_frames

In [None]:
# Shape of frame 3 as an array.
im.seek(3)
np.array(im).shape

In [None]:
images = list(PIL.ImageSequence.Iterator(im))

In [None]:
# First three frames (cf. the default wsize=3 of the predict functions).
w_imgs = images[0:3]

In [None]:
# NOTE(review): binds `img` to the stacked frames' shape tuple, not to the stack itself.
img = np.stack([np.array(img) for img in w_imgs]).shape

In [None]:
w_imgs[2]

In [None]:
fn

In [None]:
!ls -hatlr *.tif

In [None]:
%debug

In [None]:
x.shape

In [None]:
doc(conv_layer)

In [None]:
%debug