In [1]:
from fastai import *
from fastai.vision import *
from fastai.callbacks import *
from superres import superres_metrics, ssim, psnr

# Root folder of the experiment data and the name of the prepared dataset to use.
data_path = Path('/scratch/bpho')
datasetname = 'multiframe_001'

# Top-level folders below the data root.
datasources = data_path/'datasources'
datasets = data_path/'datasets'
dataset = datasets/datasetname

# Image folders inside the chosen dataset: `get_src` reads inputs from `lr_path`
# or `lr_up_path` and maps each file to its counterpart under `hr_path`.
hr_path, lr_path, lr_up_path = (dataset/sub for sub in ('hr', 'lr', 'lr_up'))

In [2]:
class NpyRawImageList(ImageList):
    "`ImageList` whose items are arrays stored in `.npy` files."
    def open(self, fn):
        "Load the array saved at `fn`, prepend one axis and wrap the result in an `Image`."
        return Image(tensor(np.load(fn)[None]))

class TransformableLists(ItemBase):
    "Item made of a list of image lists that are stacked into a single tensor."
    def __init__(self, img_lists):
        # `img_lists` is a list of lists; every inner element exposes `.data` and `.apply_tfms`.
        self.img_lists = img_lists

    def __repr__(self):
        return f'MultiImg: {len(self.img_lists)}'

    @property
    def data(self):
        "Stack each inner list, then stack those results, adding two leading axes."
        per_list = [torch.stack([im.data for im in group]) for group in self.img_lists]
        return tensor(torch.stack(per_list))

    def apply_tfms(self, tfms, **kwargs):
        "Apply `tfms` to every image, replace `img_lists` with the results and return `self`."
        # Only the very first image is transformed with do_resolve=True; presumably this
        # makes all remaining images reuse the same random parameters - TODO confirm
        # against fastai's `Image.apply_tfms`.
        resolve = True
        transformed = []
        for group in self.img_lists:
            new_group = []
            for im in group:
                new_group.append(im.apply_tfms(tfms, do_resolve=resolve, **kwargs))
                resolve = False
            transformed.append(new_group)
        self.img_lists = transformed
        return self
    
class MultiImageDataBunch(ImageDataBunch):
    "`ImageDataBunch` whose batch statistics treat dim 3 of the input batch as the channel axis."
    def batch_stats(self, funcs:Collection[Callable]=None)->Tensor:
        "Grab one batch (validation set if present, else training) and reduce it per channel with each of `funcs`."
        funcs = ifnone(funcs, [torch.mean,torch.std])
        ds_type = DatasetType.Valid if self.valid_dl else DatasetType.Train
        x = self.one_batch(ds_type=ds_type, denorm=False)[0].cpu()

        def per_channel(batch):
            # Bring dim 3 to the front and flatten everything else into one row per channel.
            n_channels = batch.shape[3]
            return batch.transpose(3,0).contiguous().view(n_channels,-1)

        stats = []
        for reduce_fn in funcs:
            stats.append(reduce_fn(per_channel(x), 1))
        return stats

    def normalize(self, stats:Collection[Tensor]=None, do_x:bool=True, do_y:bool=False)->None:
        "Add a normalize transform built from `stats` (defaults to `batch_stats()`) and return `self`."
        if getattr(self,'norm',False): raise Exception('Can not call normalize twice')
        self.stats = self.batch_stats() if stats is None else stats
        self.norm,self.denorm = normalize_funcs(*self.stats, do_x=do_x, do_y=do_y)
        self.add_tfm(self.norm)
        return self
    
class MultiImageList(ItemList):
    "`ItemList` whose items are `.npy` files holding a 2-level grid of 2-D images."
    _bunch = MultiImageDataBunch

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def open(self, fn):
        "Load `fn` and wrap every `[j, i]` slice, with a new leading axis, into an `Image`."
        stack = np.load(fn)
        groups = [[Image(tensor(stack[j,i,:,:][None])) for i in range(stack.shape[1])]
                  for j in range(stack.shape[0])]
        return TransformableLists(groups)

    def get(self, i):
        "Fetch item `i` from the parent list and open it as a `TransformableLists`."
        return self.open(super().get(i))
    

class MultiToMultiImageList(MultiImageList):
    "`MultiImageList` whose `_label_cls` is `NpyRawImageList`."
    # presumably fastai's data block API reads `_label_cls` to choose the label list type - TODO confirm
    _label_cls = NpyRawImageList

In [3]:
def map_to_hr(x):
    """Return the path under `hr_path` that mirrors `x`'s location relative to `lr_path`.

    Only valid for paths located under `lr_path`. `get_src` defines its own nested
    `map_to_hr` that follows whichever input folder it selected.
    """
    return hr_path/x.relative_to(lr_path)

def get_src(size=128, scale=1):
    """Build the labelled item lists: `.npy` inputs paired with their files under `hr_path`.

    Inputs come from `lr_up_path` when `scale` is 1 and from `lr_path` otherwise.
    `size` is accepted but not used here.
    """
    use_lr_path = lr_up_path if scale == 1 else lr_path

    def map_to_hr(x):
        # Same relative location, but below the high-resolution folder.
        return hr_path/x.relative_to(use_lr_path)

    items = MultiToMultiImageList.from_folder(use_lr_path, extensions=['.npy'])
    return items.split_by_folder().label_from_func(map_to_hr)

def get_data(bs, size, scale=1, max_zoom=1.):
    """Create the normalized data bunch used for training.

    Inputs are transformed at `size`, targets at `size*scale`, with the same
    transform set (vertical flips enabled, zoom limited by `max_zoom`). Both inputs
    and targets are normalized (`do_y=True`).
    """
    labelled = get_src(size, scale=scale)
    augment = get_transforms(flip_vert=True, max_zoom=max_zoom)
    transformed = labelled.transform(augment, size=size).transform_y(augment, size=size*scale)
    bunch = transformed.databunch(bs=bs)
    return bunch.normalize(do_y=True)

In [4]:
# Scratch cell: prints a fixed string and touches no notebook state.
print('cool')

cool


In [5]:
class TwoXModel(nn.Module):
    """Run one shared model on the two entries along dim 1 of the input.

    `forward` returns the two outputs stacked along a new leading dimension, so
    index 0 holds the result for `x[:,0]` and index 1 the result for `x[:,1]`.
    """
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        # Slice both inputs first, then push each one through the same weights.
        pair = (x[:,0], x[:,1])
        return torch.stack([self.model(part) for part in pair])

In [11]:
class ShortcutBlock(nn.Module):
    "Add the output of `submodule` elementwise to its input."
    def __init__(self, submodule):
        super().__init__()
        self.sub = submodule

    def forward(self, x):
        return x + self.sub(x)

    def __repr__(self):
        # Prefix every line of the wrapped module's repr with '|'.
        inner = repr(self.sub).replace('\n', '\n|')
        return 'Identity + \n|' + inner


def sequential(*args):
    """Combine modules into one flat `nn.Sequential`.

    A single argument is returned unchanged (an `OrderedDict` is rejected with
    `NotImplementedError`). With several arguments, the children of any
    `nn.Sequential` are unwrapped in place, other `nn.Module`s are kept as they
    are, and anything that is not a module is dropped.
    """
    if len(args) == 1:
        only = args[0]
        if isinstance(only, OrderedDict):
            raise NotImplementedError('sequential does not support OrderedDict input.')
        return only  # nothing to wrap
    flat = []
    for item in args:
        if isinstance(item, nn.Sequential):
            flat.extend(item.children())
        elif isinstance(item, nn.Module):
            flat.append(item)
    return nn.Sequential(*flat)


class ResidualDenseBlock_5C(nn.Module):
    """Five-convolution densely connected block with a scaled residual connection.

    Every convolution receives the block input concatenated (along dim 1) with the
    outputs of all earlier convolutions. The fifth output is multiplied by 0.2 and
    added to the block input.

    Args:
        nc: number of channels of the block input, and therefore of its output.
        gc: growth channels, i.e. how many channels each of the first four
            convolutions produces.
    """
    def __init__(self, nc, gc=32):
        super().__init__()
        self.conv1 = conv_layer(nc, gc, norm_type=NormType.Weight, leaky=0.2)
        self.conv2 = conv_layer(nc+gc, gc, norm_type=NormType.Weight, leaky=0.2)
        self.conv3 = conv_layer(nc+2*gc, gc, norm_type=NormType.Weight, leaky=0.2)
        self.conv4 = conv_layer(nc+3*gc, gc, norm_type=NormType.Weight, leaky=0.2)
        # The last convolution has to map back to `nc` channels because `forward`
        # sums its output with the block input; emitting `gc` channels here would
        # only work when nc == gc.
        # TODO(review): open question carried over from the author - turn off the
        # activation on this last convolution?
        self.conv5 = conv_layer(nc+4*gc, nc, norm_type=NormType.Weight, leaky=0.2)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        # Scaled residual: damp the dense branch before adding it to the input.
        return x5.mul(0.2) + x
    
class RRDB(nn.Module):
    """Three `ResidualDenseBlock_5C` blocks in series wrapped in a scaled residual.

    The output of the third block is multiplied by 0.2 and added to the input.
    """
    def __init__(self, nc, gc=32):
        super().__init__()
        self.RDB1 = ResidualDenseBlock_5C(nc, gc)
        self.RDB2 = ResidualDenseBlock_5C(nc, gc)
        self.RDB3 = ResidualDenseBlock_5C(nc, gc)

    def forward(self, x):
        branch = x
        for dense_block in (self.RDB1, self.RDB2, self.RDB3):
            branch = dense_block(branch)
        return branch.mul(0.2) + x
    
class RRDB_Net(nn.Module):
    """Network assembled from a feature conv, a trunk of `RRDB` blocks and upsamplers.

    Layout: feature convolution -> shortcut around (`nb` RRDB blocks + a conv) ->
    `PixelShuffle_ICNR` upsampler(s) -> two output convolutions.

    Args:
        in_nc: channels of the input tensor.
        out_nc: channels of the output tensor.
        nf: number of feature channels used inside the network.
        nb: number of `RRDB` blocks in the trunk.
        gc: growth channels forwarded to every `RRDB` block.
        upscale: requested upscaling factor. 3 builds one upsampler with `scale=3`;
            any other value builds `int(log2(upscale))` upsamplers with the
            default scale of `PixelShuffle_ICNR` (presumably 2 - confirm against fastai).
    """
    def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4):
        super().__init__()

        fea_conv = conv_layer(in_nc, nf, norm_type=None, use_activ=False)
        # Forward the caller's `gc` instead of a fixed 32 so the argument takes effect.
        rb_blocks = [RRDB(nf, gc=gc) for _ in range(nb)]
        LR_conv = conv_layer(nf, nf, leaky=0.2)

        # `upsampler` is always a list because it is unpacked with `*` below; a bare
        # module cannot be unpacked.
        if upscale == 3:
            upsampler = [PixelShuffle_ICNR(nf, blur=True, leaky=0.2, scale=3)]
        else:
            # log2 is exact for powers of two, unlike log(x, 2) which can round
            # just below the integer and be truncated by int().
            n_upscale = int(math.log2(upscale))
            upsampler = [PixelShuffle_ICNR(nf, blur=True, leaky=0.2) for _ in range(n_upscale)]

        HR_conv0 = conv_layer(nf, nf, leaky=0.2)
        HR_conv1 = conv_layer(nf, out_nc, norm_type=None, use_activ=False)

        self.model = sequential(
            fea_conv,
            ShortcutBlock(sequential(*rb_blocks, LR_conv)),
            *upsampler, HR_conv0, HR_conv1
        )

    def forward(self, x):
        return self.model(x)

class MultiImageToMultiChannel(nn.Module):
    """Merge all middle dimensions of the input into one before calling `model`.

    The first dimension and the last two are kept; everything in between is folded
    into dim 1 with `view`, so the wrapped model receives a 4-D tensor.
    """
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        batch, height, width = x.shape[0], x.shape[-2], x.shape[-1]
        return self.model(x.view(batch, -1, height, width))
    
class TwoYLoss(nn.Module):
    """Loss for a pair of predictions scored against one shared target.

    `input[0]` and `input[1]` are each compared to `target` with `base_loss`; an
    extra MSE term between the two predictions is weighted by `stable_wt`. The two
    base terms share the remaining weight equally. Every call also stores the
    individual terms plus the averaged SSIM and PSNR in `self.metrics`.
    """
    def __init__(self, base_loss=F.mse_loss, stable_wt=0.15):
        super().__init__()
        self.base_loss, self.stable_wt = base_loss, stable_wt
        # Each per-prediction term gets half of whatever the stability term leaves over.
        self.base_loss_wt = (1-stable_wt)/2
        self.metric_names = ['pixel_1','pixel_2','stable','ssim','psnr']

    def forward(self, input, target):
        pred_a, pred_b = input[0], input[1]

        err_a = self.base_loss(pred_a, target)
        err_b = self.base_loss(pred_b, target)
        # Disagreement between the two predictions.
        drift = F.mse_loss(pred_a, pred_b)

        pixel_wt = self.base_loss_wt
        total = err_a * pixel_wt + err_b * pixel_wt + drift * self.stable_wt

        mean_ssim = (ssim.ssim(pred_a, target)+ssim.ssim(pred_b,target))/2
        mean_psnr = (psnr(pred_a, target)+psnr(pred_b,target))/2
        self.metrics = {
            'pixel_1': err_a,
            'pixel_2': err_b,
            'stable': drift,
            'ssim': mean_ssim,
            'psnr': mean_psnr,
        }
        return total

In [12]:
# Batch size and input size handed to `get_data`; with scale=4 the targets are
# transformed at size*4.
bs, size = 8, 128
data = get_data(bs, size, scale=4)

In [13]:
# Network hyper-parameters: input/output channels, feature width and number of RRDB blocks.
in_c, out_c = 3, 1
nf, nb = 32, 5

loss = TwoYLoss()
# The same RRDB network is applied to both input groups (TwoXModel) after each
# group's images have been folded into channels (MultiImageToMultiChannel).
model = TwoXModel(MultiImageToMultiChannel(RRDB_Net(in_c, out_c, nf, nb)))
learn = Learner(data, model, callback_fns=LossMetrics, loss_func=loss)

In [None]:
# Run the learner's learning-rate finder, then plot what its recorder captured.
# This cell has no recorded execution in the saved notebook.
learn.lr_find()
learn.recorder.plot()

In [14]:
# Train for one epoch with max_lr=5e-3; the table that follows is the saved output of this run.
learn.fit_one_cycle(1, max_lr=5e-3)

epoch,train_loss,valid_loss,pixel_1,pixel_2,stable,ssim,psnr,time
0,0.152352,0.13733,0.150685,0.150756,0.061452,0.209416,8.265573,01:48


In [None]:
# Ten further epochs at a lower max_lr (1e-3); this cell has no recorded execution.
learn.fit_one_cycle(10, max_lr=1e-3)

In [None]:

    
    
# NOTE(review): rebinds `loss` to a fresh TwoYLoss. `learn` was already built with the
# earlier instance, so this does not change the learner's loss function.
loss = TwoYLoss()    

In [None]:
# Interactive documentation lookup for `fit_one_cycle` (presumably fastai's `doc` helper);
# nothing later in the notebook depends on it.
doc(learn.fit_one_cycle)

In [None]:
# NOTE(review): `ssim.ssim` is called with two tensors in TwoYLoss.forward; calling it with
# no arguments here will most likely raise - scratch cell, consider removing.
ssim.ssim()