In [6]:
import fastai
from fastai import *
from fastai.vision import *
from fastai.callbacks import *
from fastai.script import *
from skimage import filters
from skimage.util import random_noise

# loading data

In [7]:
from data.load import get_data

# Dataset root; the LR/HR sub-directories below are given relative to it.
data_path = Path('/home/alaa/Dropbox (BPHO)/BPHO Staff/USF/')
lr_path = 'EM/training/trainsets/crappified/'  # LR input images (relative to data_path)
hr_path = 'EM/training/trainsets/hr/'          # HR target images (relative to data_path)

bs = 8      # batch size
size = 128  # image size passed to get_data

db = get_data(data_pth=data_path, lr_dir=lr_path, hr_dir=hr_path, bs=bs, size=size)

db.show_batch()

In [None]:
# Directory for EM model checkpoints, under the dataset root `data_path` set in the data-loading cell.
model_path = data_path/'EM/models/'

In [None]:
GPU_ID = 3  # index of the CUDA device used for training; adjust for the local machine
torch.cuda.set_device(GPU_ID)

# resnet.py

In [None]:
from functools import partial
import fastai.vision.learner as fvl

In [None]:
# Shared activation: this single ReLU module is reused by conv_layer and by ResBlock.forward.
# inplace=True makes it overwrite its input tensor instead of allocating a new one.
act_fn = nn.ReLU(inplace=True)
def _init_weight_norm_(m, init_fn=nn.init.kaiming_normal_):
    """Initialise a module wrapped with `nn.utils.weight_norm`, in place.

    weight_norm replaces `m.weight` with the parameters `weight_g` and `weight_v`
    and registers a forward pre-hook that recomputes `weight = g * v / ||v||`
    before every forward. Writing into `m.weight` directly is therefore lost.
    This helper applies `init_fn` to `v` and sets `g = ||v||`, so the effective
    weight equals the freshly initialised `v`.
    """
    v, g = m.weight_v, m.weight_g
    with torch.no_grad():
        init_fn(v)
        # `g` keeps size 1 along every dim the norm is taken over (it is 0-d when
        # the norm spans the whole tensor).
        if g.dim() == 0:
            dims = list(range(v.dim()))
        else:
            dims = [d for d in range(v.dim()) if g.shape[d] == 1]
        g.copy_(v.pow(2).sum(dim=dims, keepdim=True).sqrt().reshape(g.shape))


def init_cnn(m):
    """Recursively initialise `m` and all of its sub-modules in place.

    - Biases are set to 0.
    - Conv2d/Linear weights get Kaiming-normal initialisation.

    For modules wrapped with `nn.utils.weight_norm` (as built by `conv_layer`),
    the init is applied to the underlying `weight_v`/`weight_g` parameters.
    Initialising the recomputed `weight` tensor would be discarded on the
    next forward pass.
    """
    if getattr(m, 'bias', None) is not None: nn.init.constant_(m.bias, 0)
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        if hasattr(m, 'weight_v') and hasattr(m, 'weight_g'):
            _init_weight_norm_(m, nn.init.kaiming_normal_)
        else:
            nn.init.kaiming_normal_(m.weight)
    for l in m.children(): init_cnn(l)
def conv(ni, nf, ks=3, stride=1, bias=False):
    """Conv2d from `ni` to `nf` channels.

    Padding is `ks // 2`, which keeps the spatial size for odd `ks` at stride 1.
    """
    pad = ks // 2
    return nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=pad, bias=bias)

def noop(x):
    """Identity function: return `x` unchanged (stands in for an optional layer)."""
    return x

def conv_layer(ni, nf, ks=3, stride=1, act=True):
    """Weight-normalised `conv`, optionally followed by the shared `act_fn`.

    The result is wrapped in an nn.Sequential.
    """
    wn_conv = nn.utils.weight_norm(conv(ni, nf, ks, stride=stride))
    return nn.Sequential(wn_conv, act_fn) if act else nn.Sequential(wn_conv)

class ResBlock(nn.Module):
    """Residual block built from weight-normalised `conv_layer`s.

    `ni` and `nh` are given before expansion: the block reads `ni * expansion`
    channels and writes `nh * expansion` channels. `stride` is applied to the
    3x3 conv of the main path and, via average pooling, to the shortcut.
    """
    def __init__(self, expansion, ni, nh, stride=1):
        super().__init__()
        nf,ni = nh*expansion,ni*expansion
        # Every block opens with a 1x1 conv ni -> nh.
        # NOTE(review): for expansion==1 this yields 1x1 -> 3x3 rather than the
        # 3x3 -> 3x3 basic block of a standard ResNet-34 -- confirm this is intended.
        layers  = [conv_layer(ni, nh, 1)]
        layers += [
            conv_layer(nh, nf, 3, stride=stride,act=False)
        ] if expansion==1 else [
            # bottleneck: strided 3x3, then a 1x1 expanding to nf without activation
            conv_layer(nh, nh, 3, stride=stride),
            conv_layer(nh, nf, 1, act=False)
        ]
        self.convs = nn.Sequential(*layers)
        # TODO: check whether act=True works better
        # Shortcut: identity when channel counts match, otherwise a 1x1 projection.
        self.idconv = noop if ni==nf else conv_layer(ni, nf, 1, act=False)
        # Downsample the shortcut to match the strided main path.
        self.pool = noop if stride==1 else nn.AvgPool2d(2, ceil_mode=True)

    # Main path plus (pooled, possibly projected) shortcut, then the shared activation.
    def forward(self, x): return act_fn(self.convs(x) + self.idconv(self.pool(x)))

def filt_sz(recep):
    """Largest power of two not exceeding `0.75 * recep`, capped at 64."""
    exponent = math.floor(math.log2(recep * 0.75))
    return min(64, 2 ** exponent)

class WNResNet(nn.Sequential):
    """ResNet-style classifier in which every convolution is weight-normalised.

    Layout, in order:
    - a 3-conv stem (c_in -> 32 -> 32 -> 64; only the first conv has stride 2);
    - a 3x3 max-pool;
    - one stage of ResBlocks per entry of `layers` (stage 0 keeps the resolution,
      later stages use stride 2);
    - adaptive average pooling, `Flatten`, and a linear layer to `c_out` outputs.

    Module order matters: `_wnresnet_split` and the `cut` metadata index into it.
    """
    def __init__(self, expansion, layers, c_in=3, c_out=1000):
        stem_chs = [c_in, 32, 32, 64]
        stem = [conv_layer(ch_in, ch_out, stride=2 if idx == 0 else 1)
                for idx, (ch_in, ch_out) in enumerate(zip(stem_chs, stem_chs[1:]))]

        # Stage widths before expansion; stage i maps stage_chs[i] -> stage_chs[i+1].
        stage_chs = [64 // expansion, 64, 128, 256, 512]
        blocks = []
        for idx, n_blocks in enumerate(layers):
            stride = 1 if idx == 0 else 2
            blocks.append(self._make_layer(expansion, stage_chs[idx], stage_chs[idx + 1], n_blocks, stride))

        super().__init__(
            *stem,
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            *blocks,
            nn.AdaptiveAvgPool2d(1), Flatten(),
            nn.Linear(stage_chs[-1] * expansion, c_out),
        )
        init_cnn(self)

    def _make_layer(self, expansion, ni, nf, blocks, stride):
        """Stack `blocks` ResBlocks.

        Only the first block changes the channel count (`ni` -> `nf`) and applies `stride`.
        """
        res_blocks = []
        for idx in range(blocks):
            first = idx == 0
            res_blocks.append(ResBlock(expansion, ni if first else nf, nf, stride if first else 1))
        return nn.Sequential(*res_blocks)

def wnresnet(expansion, n_layers, name, pretrained=False, **kwargs):
    """Build a `WNResNet`.

    `name` and `pretrained` are accepted but not used; no pretrained weights
    are loaded. Extra kwargs (e.g. `c_in`) go to WNResNet.
    """
    return WNResNet(expansion, n_layers, **kwargs)

def _wnresnet_split(m:nn.Module): return (m[0][6],m[1])
_wnresnet_meta     = {'cut':-2, 'split':_wnresnet_split }

In [None]:
name = 'wnresnet34'
# Expansion 1 with 3/4/6/3 blocks per stage (the ResNet-34 stage layout).
e, l = 1, [3, 4, 6, 3]

# Register the architecture with fastai so cnn_config(arch) returns its cut/split metadata.
arch = partial(wnresnet, expansion=e, n_layers=l, name=name)
fvl.model_meta[arch] = dict(_wnresnet_meta)

# unet.py

In [None]:
from fastai.vision.models.unet import DynamicUnet
from fastai.vision.learner import cnn_config

In [None]:
class BilinearWrapper(nn.Module):
    """Upsample the input by `scale` with `F.interpolate`, then run `model` on it.

    `mode` may be any `F.interpolate` mode. `align_corners=False` is passed only
    for the interpolating modes that accept it (linear/bilinear/bicubic/trilinear).
    The other modes (e.g. 'nearest', 'area') raise ValueError if `align_corners`
    is given.
    """
    _ALIGN_CORNERS_MODES = ('linear', 'bilinear', 'bicubic', 'trilinear')

    def __init__(self, model, scale=4, mode='bilinear'):
        super().__init__()
        self.model = model
        self.scale = scale
        self.mode = mode

    def forward(self, x):
        align = False if self.mode in self._ALIGN_CORNERS_MODES else None
        upsampled = F.interpolate(x, scale_factor=self.scale, mode=self.mode, align_corners=align)
        return self.model(upsampled)

In [None]:
def wnres_unet_model(in_c, out_c, arch, blur=True, blur_final=True, self_attention=True, last_cross=True, bottle=True, norm_type=NormType.Weight, **wnres_args):
    """Build a fastai `DynamicUnet` whose encoder is the body of `arch(c_in=in_c)`.

    Args:
        in_c: number of input channels for the encoder.
        out_c: number of output channels (`n_classes` of the U-Net).
        arch: encoder constructor registered in fastai's `model_meta`.
        blur, blur_final, self_attention, last_cross, bottle, norm_type:
            forwarded to `DynamicUnet`.
        **wnres_args: any further keyword arguments, also forwarded to `DynamicUnet`.

    Returns:
        `(model, meta)`: the U-Net, and the metadata dict from `cnn_config(arch)`.
        Its 'cut' is used here; callers use its 'split'.
    """
    meta = cnn_config(arch)
    enc_model = arch(c_in=in_c)
    # Keep the encoder children before meta['cut'] (drops the pooling/classifier head).
    body = nn.Sequential(*list(enc_model.children())[:meta['cut']])

    model = DynamicUnet(body,
                        n_classes=out_c,
                        blur=blur,
                        blur_final=blur_final,
                        self_attention=self_attention,
                        norm_type=norm_type,
                        last_cross=last_cross,
                        bottle=bottle, **wnres_args)
    return model, meta


def wnres_unet_learner(data, arch, in_c=1, out_c=1, wnres_args=None, bilinear_upsample=True, **kwargs):
    """Create a fastai `Learner` around the weight-normalised ResNet U-Net.

    Args:
        data: fastai DataBunch.
        arch: encoder constructor registered in fastai's `model_meta`.
        in_c, out_c: input / output channel counts.
        wnres_args: optional dict of keyword arguments for `wnres_unet_model`.
        bilinear_upsample: if True, wrap the model in a `BilinearWrapper`, so
            inputs are upsampled (default scale 4) before the U-Net.
        **kwargs: forwarded to `Learner` (loss_func, metrics, wd, ...).
    """
    unet_kwargs = {} if wnres_args is None else wnres_args
    model, meta = wnres_unet_model(in_c, out_c, arch, **unet_kwargs)
    learn = Learner(data, model, **kwargs)
    learn.split(meta['split'])
    # NOTE(review): presumably mirrors fastai's unet_learner; confirm model[2] actually has parameters to init.
    apply_init(model[2], nn.init.kaiming_normal_)
    if bilinear_upsample:
        # The wrapper holds the same module, so the groups from learn.split still refer to its parameters.
        learn.model = BilinearWrapper(learn.model)
    return learn

# train.py

In [None]:
import torch.nn.functional as F
# Training objective: pixel-wise mean-squared error.
loss = F.mse_loss 
from model.metrics import psnr, ssim
# Metrics passed to the Learner: the MSE loss itself plus the project's PSNR and SSIM.
metrics = [loss, psnr, ssim]

In [None]:
# Keyword arguments passed through wnres_unet_learner to wnres_unet_model.
wnres_args = dict(
    blur=True,
    blur_final=True,
    bottle=True,
    self_attention=True,
    last_cross=True,
)
wd = 1e-3  # weight decay, passed to the Learner

In [None]:
# Passed as the Learner's `model_dir` (the learner cell uses path=Path('.')).
model_dir = './model/test'

In [None]:
save_name = 'combo'
# NOTE(review): this rebinds `size` (128 when the data were loaded) only to label
# checkpoints; the data were not reloaded at 256 -- confirm the intended label.
size = 256
# SaveModelCallback writes its checkpoint under the name '<save_name>_best_<size>'.
callback_fns = [partial(SaveModelCallback, name=f'{save_name}_best_{size}')]

In [None]:
# Build the U-Net learner on the data loaded above.
learn = wnres_unet_learner(
    db, arch,
    wnres_args=wnres_args,
    path=Path('.'),
    loss_func=loss,
    metrics=metrics,
    model_dir=model_dir,
    callback_fns=callback_fns,
    wd=wd,
)

In [None]:
# Max learning rate for fit_one_cycle; equal to slice(None, 1e-4, None).
lr = slice(1e-4)

In [None]:
# Train for 1 epoch with the one-cycle policy, using `lr` as the max learning rate.
learn.fit_one_cycle(1, lr)

In [None]:
# Visual sanity check of model predictions against their targets.
learn.show_results()

In [None]:
# Save the final weights as <model_dir>/<save_name>_final_<size>.pth.
# An empty name would write a hidden '.pth' file that the next bare save overwrites.
learn.save(f'{save_name}_final_{size}')