In [1]:
#import fastai
from fastai import *          # Quick access to most common functionality
from fastai.vision import *   # Quick access to computer vision functionality
from fastai.callbacks import *

In [2]:
import pytorch_ssim as ssim
from superres import *
from torchvision.models import vgg16_bn
import czifile

In [17]:
# Data layout. The root defaults to the original cluster location but can be overridden
# with the BPHO_DATA_ROOT environment variable so the notebook runs on other machines.
import os
path = Path(os.environ.get('BPHO_DATA_ROOT', '/DATA/WAMRI/salk/uri/BPHO/'))
path_processed = path/'processed'   # source .czi stacks
model_dir = path/'models'
path_hr = path/'hires'              # per-slice .npy targets
path_mr = path/'midres'             # mid-res .jpg inputs
path_lr = path/'lores'              # low-res .jpg inputs
path_test = path/'test'

# parents=True so a fresh (or overridden) root does not fail with FileNotFoundError.
for out_dir in (path_hr, path_mr, path_lr, path_test):
    out_dir.mkdir(parents=True, exist_ok=True)

In [18]:
def get_czi_shape_info(czi):
    """Describe the axes of a CZI file.

    Returns `(axes_dict, shape_dict)`. `axes_dict` maps each axis letter of `czi.axes`
    to its position; `shape_dict` maps the same letter to that axis' length in `czi.shape`.
    """
    dims = czi.shape
    axes_dict, shape_dict = {}, {}
    for pos, letter in enumerate(czi.axes):
        axes_dict[letter] = pos
        shape_dict[letter] = dims[pos]
    return axes_dict, shape_dict


def build_index(axes, ix_select):
    """Build an indexing tuple with one entry per axis in `axes`.

    The entry is `ix_select[axis]` when that axis is selected, otherwise 0. Keys of
    `ix_select` that are not in `axes` are ignored.
    """
    return tuple(ix_select[ax] if ax in ix_select else 0 for ax in axes)


def process_czi(proc_fn, save_dir=None):
    """Split one CZI stack into a `.npy` file per (channel, z-slice).

    Each file is written to `save_dir` (default: `path_hr`) as
    `{proc_fn.stem}_{channel:02d}_{depth:03d}.npy`. It contains the full X/Y extent; every axis
    other than C, Z, X and Y is fixed at index 0 (see `build_index`). A file with no C or Z
    axis is treated as having a single channel / single slice instead of raising KeyError.
    """
    save_dir = path_hr if save_dir is None else Path(save_dir)
    with czifile.CziFile(proc_fn) as proc_czf:
        proc_axes, proc_shape = get_czi_shape_info(proc_czf)
        # Absent axes iterate once; build_index ignores selections for axes the file lacks.
        channels = proc_shape.get('C', 1)
        depths = proc_shape.get('Z', 1)
        data = proc_czf.asarray()
        for channel in range(channels):
            for depth in range(depths):
                idx = build_index(proc_axes, {'C': channel, 'Z': depth,
                                              'X': slice(None), 'Y': slice(None)})
                save_proc_fn = save_dir/f'{proc_fn.stem}_{channel:02d}_{depth:03d}.npy'
                np.save(save_proc_fn, data[idx])
        

In [19]:
#proc_fns = list(path_processed.glob('*.czi'))
#for fn in progress_bar(proc_fns):
#    process_czi(fn)

In [32]:
def _npy_to_resized_jpg(fn, dest_root, min_size, max_val=8000.0):
    """Convert one `.npy` slice from `path_hr` into a resized 8-bit JPEG under `dest_root`.

    The array is scaled by `1/max_val`, clipped to [0, 1] and mapped to 0-255. It is then
    resized (bilinear) with `resize_to(img, min_size, use_min=True)` — presumably so the
    shorter side becomes `min_size`. The output keeps the path relative to `path_hr`,
    with the suffix changed to `.jpg`.
    """
    dest = (dest_root/fn.relative_to(path_hr)).with_suffix('.jpg')
    dest.parent.mkdir(parents=True, exist_ok=True)

    data = np.load(fn).astype(float) / max_val
    # Clip before the uint8 cast: values above max_val (or below 0) would otherwise
    # wrap around instead of saturating at white/black.
    data = (np.clip(data, 0.0, 1.0) * 255).astype(np.uint8)

    # assumes a 2-D (H, W) slice, as written by process_czi — TODO confirm for other inputs
    img = PIL.Image.fromarray(data, mode='L')
    targ_sz = resize_to(img, min_size, use_min=True)
    img = img.resize(targ_sz, resample=PIL.Image.BILINEAR)
    img.save(str(dest), quality=100)

def resize_one(fn,i):
    "Write the low-res (size 96) JPEG for `fn` to `path_lr`. `i` is unused (presumably the index passed by `parallel`)."
    _npy_to_resized_jpg(fn, path_lr, 96)

def resize_two(fn,i):
    "Write the mid-res (size 256) JPEG for `fn` to `path_mr`. `i` is unused (presumably the index passed by `parallel`)."
    _npy_to_resized_jpg(fn, path_mr, 256)

In [35]:
# hr_fns = list(path_hr.glob('*.npy'))
# parallel(resize_two, hr_fns)

In [22]:
# hr_fns = list(path_hr.glob('*.npy'))
# parallel(resize_one, hr_fns)

In [38]:
class ProcImageList(ImageImageList):
    "`ImageImageList` whose items are `.npy` arrays, scaled by 1/8000 and replicated to 3 channels."
    def open(self, fn):
        "Load `fn` with `np.load` and wrap it as a 3-channel float32 fastai `Image`."
        # assumes a 2-D (H, W) array as written by process_czi — TODO confirm for other inputs
        single = torch.from_numpy(np.load(fn).astype(np.float32)[None, :, :])
        single = single / 8000.
        return Image(single.repeat(3, 1, 1))


In [39]:
def get_basename(x):
    "Part of `x.stem` before the first '_' (for files written by `process_czi`, a prefix of the source CZI stem)."
    return x.stem.split('_')[0]

# Split by base name rather than by file, so all slices written from one stack fall on the same side.
# `sorted` instead of `list(set(...))`: iteration order of a set of strings changes between
# interpreter runs (hash randomization), which would reshuffle the split after every restart.
VALID_PCT, SPLIT_SEED = 0.15, 42
base_names = sorted(set(get_basename(x) for x in path_lr.iterdir()))
_np_rng_state = np.random.get_state()
np.random.seed(SPLIT_SEED)           # random_split draws from numpy's global RNG; fix it for this split
train_names, valid_names = random_split(VALID_PCT, base_names)
np.random.set_state(_np_rng_state)   # restore numpy's global RNG for everything downstream
valid_names = list(valid_names[0])

def is_validation_basename(x):
    "True if the base name of file `x` was assigned to the validation set."
    return get_basename(x) in valid_names

def _split_item_list(folder):
    "JPEG items from `folder` (labels will be `ProcImageList`), split into train/valid by base name."
    items = ImageItemList.from_folder(folder, label_cls=ProcImageList, extensions=".jpg", mode='L')
    return items.split_by_valid_func(is_validation_basename)

src = _split_item_list(path_lr)      # low-res inputs
src_mr = _split_item_list(path_mr)   # mid-res inputs

def get_data(src,bs,size, **kwargs):
    """Build a normalized DataBunch from the split item lists `src`.

    Each input is labelled with the `.npy` file of the same stem in `path_hr`. The
    `get_transforms(max_zoom=2.)` augmentations are applied to input and target alike
    (`tfm_y=True`) at `size`, and both sides are normalized with ImageNet statistics.
    Extra keyword arguments are forwarded to `.databunch`.
    """
    def hr_target(lr_fn):
        # Targets share the input's stem but live in path_hr as raw .npy slices.
        return path_hr/f'{lr_fn.stem}.npy'

    labelled = src.label_from_func(hr_target)
    augmented = labelled.transform(get_transforms(max_zoom=2.), size=size, tfm_y=True)
    data = augmented.databunch(bs=bs, **kwargs).normalize(imagenet_stats, do_y=True)
    # Three output channels: targets built by ProcImageList.open are replicated to 3 channels.
    data.c = 3
    return data

In [42]:
# Stage-1 configuration: encoder backbone for unet_learner, trained on mid-res inputs.
arch = models.resnet34
bs = 4
size = 512
data = get_data(src_mr, bs, size)

In [43]:
data.train_ds[0]  # sanity check: one (input, target) pair from the training set

(Image (3, 512, 512), Image (3, 512, 512))

In [44]:
#data.show_batch(ds_type=DatasetType.Valid, rows=2, figsize=(9,9))

In [45]:
def gram_matrix(x):
    """Per-sample Gram matrices of a batch of feature maps.

    x: tensor of shape (n, c, h, w). Returns shape (n, c, c), where
    out[b] = F_b @ F_b.T / (c*h*w) and F_b is x[b] flattened to (c, h*w).
    """
    n,c,h,w = x.size()
    # reshape (not view) so non-contiguous inputs, e.g. permuted or sliced maps, also work
    feats = x.reshape(n, c, -1)
    return (feats @ feats.transpose(1,2))/(c*h*w)

In [46]:
base_loss = F.l1_loss  # elementwise loss used by FeatureLoss for its pixel, feature and Gram terms

In [47]:
# Pretrained VGG16-BN convolutional trunk, in eval mode with gradients disabled,
# used only as a fixed feature extractor for the perceptual loss.
vgg_m = vgg16_bn(pretrained=True).features.cuda().eval()
requires_grad(vgg_m, False)

In [48]:
# Index of the layer immediately before each MaxPool2d, i.e. the last layer of every VGG block.
blocks = [pos - 1 for pos, layer in enumerate(children(vgg_m)) if isinstance(layer, nn.MaxPool2d)]
blocks, [vgg_m[pos] for pos in blocks]

([5, 12, 22, 32, 42],
 [ReLU(inplace), ReLU(inplace), ReLU(inplace), ReLU(inplace), ReLU(inplace)])

In [49]:
class FeatureLoss(nn.Module):
    """Perceptual loss combining pixel, feature-map and Gram-matrix terms.

    The returned loss is the sum of:
      * 'pixel'  : `base_loss(input, target)`;
      * 'feat_i' : `base_loss` between the hooked `m_feat` activations of input and target, times `w`;
      * 'gram_i' : `base_loss` between the Gram matrices of those activations, times `w**2 * gram_wgt`;
    where `w` is the entry of `layer_wgts` paired with the i-th entry of `layer_ids`.

    Args:
        m_feat: feature extractor indexable by layer id (in this notebook, the frozen VGG16-BN trunk).
        layer_ids: indices into `m_feat` whose outputs are hooked and compared.
        layer_wgts: one weight per entry of `layer_ids`.
        gram_wgt: extra multiplier on the Gram terms; defaults to the previously hard-coded 5e3.

    Raises:
        ValueError: if `layer_ids` and `layer_wgts` differ in length.

    After each forward call, `self.metrics` maps every name in `self.metric_names` to its
    weighted term (these are the per-term columns shown by the `LossMetrics` callback).
    """
    def __init__(self, m_feat, layer_ids, layer_wgts, gram_wgt=5e3):
        super().__init__()
        # Without this check zip() in forward() would silently drop terms, and the metric
        # names would no longer line up with the loss terms they label.
        if len(layer_ids) != len(layer_wgts):
            raise ValueError(f'FeatureLoss needs one weight per layer: got {len(layer_ids)} '
                             f'layer ids and {len(layer_wgts)} weights')
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # detach=False: gradients must flow back through the input's activations.
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.gram_wgt = gram_wgt
        self.metric_names = (['pixel']
                             + [f'feat_{i}' for i in range(len(layer_ids))]
                             + [f'gram_{i}' for i in range(len(layer_ids))])

    def make_features(self, x, clone=False):
        "Run `x` through `m_feat` and return the hooked activations (cloned if `clone`)."
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def forward(self, input, target):
        # Target features are computed (and cloned) first: the hooks' stored outputs are
        # overwritten by the second pass on `input`.
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        self.feat_losses = [base_loss(input,target)]
        self.feat_losses += [base_loss(f_in, f_out)*w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out))*w**2 * self.gram_wgt
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        # Remove the hooks registered on m_feat. Guarded because `hooks` does not exist
        # if __init__ raised before registering them.
        hooks = getattr(self, 'hooks', None)
        if hooks is not None: hooks.remove()

In [50]:
# Compare the outputs of the last three VGG blocks, weighted 5 / 15 / 2.
feat_loss = FeatureLoss(vgg_m, layer_ids=blocks[2:5], layer_wgts=[5, 15, 2])

In [51]:
wd = 1e-3
learn = unet_learner(
    data, arch,
    wd=wd,
    loss_func=feat_loss,
    callback_fns=LossMetrics,   # adds the FeatureLoss terms (pixel, feat_*, gram_*) to the training table
    blur=True,
    norm_type=NormType.Weight,
    metrics=superres_metrics,
    model_dir=model_dir,
)
gc.collect();

In [None]:
# learn.lr_find()
# learn.recorder.plot()

In [52]:
lr = 1e-2
def do_fit(save_name, lrs=slice(lr), pct_start=0.9, epochs=10):
    """Fit the global `learn` with one-cycle, save it as `save_name`, then show sample results and losses.

    Args:
        save_name: checkpoint name passed to `learn.save`.
        lrs: learning rate(s) for `fit_one_cycle`. The default `slice(lr)` is bound when this
             cell runs, so later reassignments of `lr` do not change it.
        pct_start: forwarded to `fit_one_cycle`.
        epochs: number of epochs (previously hard-coded to 10, which remains the default).
    """
    learn.fit_one_cycle(epochs, lrs, pct_start=pct_start)
    learn.save(save_name)
    learn.show_results(rows=1, imgsize=5)
    learn.recorder.plot_losses()

In [None]:
# Stage 1a: mid-res data, learning rate 10x the base `lr`.
do_fit('1a', slice(lr*10))

epoch,train_loss,valid_loss,mse_loss,ssim,psnr,pixel,feat_0,feat_1,feat_2,gram_0,gram_1,gram_2
1,2.973625,2.588928,1.067484,0.517662,-0.278924,0.922366,0.201913,0.212321,0.011464,0.701248,0.536879,0.002739
2,1.203431,0.890976,0.007507,0.787360,23.374224,0.035363,0.152004,0.165918,0.009526,0.260825,0.265155,0.002185
3,0.850613,0.911786,0.005997,0.873294,24.903582,0.029191,0.143298,0.151363,0.007771,0.348884,0.229287,0.001992
,,,,,,,,,,,,


In [None]:
learn.unfreeze()  # unfreeze the whole model for stage 1b

In [None]:
do_fit('1b', slice(1e-4,1e-2))  # stage 1b: discriminative learning rates from 1e-4 to 1e-2

In [None]:
# Stage 2: switch to the low-res inputs (`src`) at twice the image size and half the batch size,
# then re-freeze (stage 1b unfroze the model) before fitting.
data = get_data(src, bs//2,size*2)
learn.data = data
learn.freeze()

In [None]:
do_fit('2a')  # stage 2a: default lrs (slice(lr)) on the re-frozen model with low-res inputs

In [None]:
learn.unfreeze()  # unfreeze the whole model for stage 2b

In [None]:
do_fit('2b', slice(1e-5,1e-3), pct_start=0.3)  # stage 2b: lower discriminative LRs, pct_start=0.3 instead of 0.9

In [None]:
# Stage 3: mid-res inputs at 512px, starting from the stage-2b weights.
bs,size = 2,512
data = get_data(src_mr, bs, size)
learn = learn.load('2b')
# Attach the new DataBunch before fitting; otherwise do_fit would keep training on the
# stage-2 (low-res, 2x size) data still held by `learn`.
learn.data = data
do_fit('3', slice(1e-4,1e-2), pct_start=0.3)

In [None]:
learn.data = data
# Sorted so that tst_imgs[i] names the same file on every run; iterdir() order is filesystem-dependent.
tst_imgs  = sorted(Path('/DATA/WAMRI/salk/uri/Image_restoration_data/newimg/').iterdir())

In [None]:
fn = tst_imgs[1]
# open_grayscale is not defined in this notebook (presumably provided by superres) — confirm
img = open_grayscale(fn)

# learn.predict returns three values; only the first (`a`, a fastai Image) is used below
a,b,c = learn.predict(img)

In [None]:
img  # display the input test image

In [None]:
agray = Image(a.data[0:1,:,:])  # keep only the first channel of the prediction (presumably 3 identical channels)

In [None]:
agray.size  # size of the single-channel prediction

In [None]:
agray  # display the single-channel prediction

In [None]:
# SSIM between the first prediction channel and the input image ([None] adds a batch dimension).
# NOTE(review): confirm that the open_grayscale input and the model prediction share the same
# intensity scale and spatial size before reading much into this number.
ssim.ssim(agray.data[None], img.data[None])