In [None]:
import fastai
from fastai.vision import *
from fastai.callbacks import *
from multiprocessing import Pool
import PIL
from torchvision.models import vgg16_bn

### Dataset

In [None]:
# Dataset root. The LR* folders hold the low-resolution inputs and the HR*
# folders hold the matching high-resolution targets (paired by file name in get_data*).
path = Path('trainDataset/train')
savePath = path
path_lr, path_hr = path/'LR', path/'HR'
path_lr2, path_hr2 = path/'LR2', path/'HR2'

In [None]:
bs = 6
size = 256
# Hold out 10% of the low-res images for validation; the fixed seed keeps the split reproducible.
src1 = ImageImageList.from_folder(path_lr).random_split_by_pct(valid_pct=0.1, seed=42)
src2 = ImageImageList.from_folder(path_lr2).random_split_by_pct(valid_pct=0.1, seed=42)

In [None]:
def _get_data(src, hr_path, bs, size):
    """Build a normalized image-to-image DataBunch from a train/valid split.

    Each low-res input ``x`` is labelled with the file of the same name under
    ``hr_path``. The same random transforms (zoom up to 2x) are applied to the
    input and the target (``tfm_y=True``), both are resized to ``size``, and
    both are normalized with ImageNet statistics (``do_y=True``).

    Args:
        src: split item list of low-res images (``src1`` or ``src2``).
        hr_path: folder containing the high-res target images.
        bs: batch size.
        size: image size after transforms.

    Returns:
        The DataBunch with ``c`` set to 3 (presumably so the U-Net head
        produces a 3-channel RGB image).
    """
    data = (src.label_from_func(lambda x: hr_path/x.name)
            .transform(get_transforms(max_zoom=2.), size=size, tfm_y=True)
            .databunch(bs=bs).normalize(imagenet_stats, do_y=True))
    data.c = 3
    return data


def get_data1(bs, size):
    """DataBunch for the first dataset: ``LR`` inputs -> ``HR`` targets."""
    return _get_data(src1, path_hr, bs, size)


def get_data2(bs, size):
    """DataBunch for the second dataset: ``LR2`` inputs -> ``HR2`` targets."""
    return _get_data(src2, path_hr2, bs, size)

In [None]:
# Start training on the first dataset.
data = get_data1(bs=bs, size=size)

### Model i funkcja straty

In [None]:
def gram_matrix(x):
    """Batched Gram matrix of a feature map, normalized by the map's size.

    Args:
        x: tensor of shape ``(n, c, h, w)``.

    Returns:
        Tensor of shape ``(n, c, c)``. Entry ``[b, i, j]`` is the dot product of
        channels ``i`` and ``j`` of sample ``b`` over all spatial positions,
        divided by ``c*h*w``.
    """
    n, c, h, w = x.size()
    # reshape (not view) so non-contiguous inputs, e.g. spatially transposed
    # tensors, are accepted. Contiguous inputs still get a view with no copy.
    feats = x.reshape(n, c, h * w)
    return (feats @ feats.transpose(1, 2)) / (c * h * w)

In [None]:
# Distance used for every term of FeatureLoss (pixel, feature and Gram): L1 / mean absolute error.
base_loss = F.l1_loss

In [None]:
# Frozen, pretrained VGG16-BN convolutional trunk used as the perceptual-loss
# feature extractor. It is placed on fastai's default device (GPU when
# available, otherwise CPU) rather than a hard-coded .cuda(), so the notebook
# also runs on CPU-only machines.
vgg_m = vgg16_bn(pretrained=True).features.to(defaults.device).eval()
requires_grad(vgg_m, False)

In [None]:
# Index of the layer immediately before each MaxPool2d in the VGG trunk,
# i.e. the last layer of every resolution stage. These are the hook candidates.
blocks = [idx - 1
          for idx, layer in enumerate(children(vgg_m))
          if isinstance(layer, nn.MaxPool2d)]
blocks, [vgg_m[idx] for idx in blocks]

In [None]:
class FeatureLoss(nn.Module):
    """Perceptual loss: pixel L1 + weighted feature-map L1 + weighted Gram-matrix L1.

    Forward hooks are attached to selected layers of the feature extractor
    ``m_feat``. Both the prediction and the target are run through it, and
    their hooked activations are compared with ``base_loss``. The individual
    terms are kept in ``self.metrics`` under ``self.metric_names``; the
    learner in this notebook is built with the LossMetrics callback to
    report them.

    Args:
        m_feat: indexable feature-extractor module (e.g. an ``nn.Sequential``).
        layer_ids: indices into ``m_feat`` of the layers whose outputs are compared.
        layer_wgts: one weight per entry of ``layer_ids``.
        gram_wgt: extra multiplier on every Gram-matrix term. Each term is
            scaled by ``w**2 * gram_wgt``. Defaults to the previously
            hard-coded ``5e3``.

    Raises:
        ValueError: if ``layer_ids`` and ``layer_wgts`` differ in length.
            ``zip`` would otherwise silently drop terms and misalign
            ``metric_names`` with ``feat_losses``.
    """

    def __init__(self, m_feat, layer_ids, layer_wgts, gram_wgt=5e3):
        super().__init__()
        if len(layer_ids) != len(layer_wgts):
            raise ValueError(
                f'layer_ids ({len(layer_ids)}) and layer_wgts ({len(layer_wgts)}) '
                'must have the same length')
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # detach=False: stored activations stay attached to the graph, so the
        # feature/Gram terms can back-propagate into the model being trained.
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.gram_wgt = gram_wgt
        self.metric_names = (['pixel']
                             + [f'feat_{i}' for i in range(len(layer_ids))]
                             + [f'gram_{i}' for i in range(len(layer_ids))])

    def make_features(self, x, clone=False):
        """Run ``x`` through ``m_feat`` and return the activations captured by the hooks.

        Use ``clone=True`` to copy the stored tensors so they survive the next
        call, which refills ``self.hooks.stored``.
        """
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def forward(self, input, target):
        """Return the summed loss and record each term in ``self.metrics``."""
        # Target features are extracted first and cloned, because the second
        # pass (on the prediction) overwrites the hooks' stored outputs.
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        self.feat_losses = [base_loss(input, target)]
        self.feat_losses += [base_loss(f_in, f_out) * w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out)) * w**2 * self.gram_wgt
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        # __init__ may have raised before the hooks were registered.
        hooks = getattr(self, 'hooks', None)
        if hooks is not None:
            hooks.remove()

In [None]:
# Compare activations from three of the pre-pool layers; the middle one is weighted most heavily.
feat_loss = FeatureLoss(vgg_m, layer_ids=blocks[2:5], layer_wgts=[5, 15, 2])

In [None]:
wd = 1e-3
arch = models.resnet34

# U-Net with a ResNet-34 encoder trained against the perceptual FeatureLoss.
# LossMetrics records the individual loss terms that feat_loss stores in .metrics.
learn = unet_learner(data, arch,
                     loss_func=feat_loss,
                     wd=wd,
                     blur=True,
                     norm_type=NormType.Weight,
                     callback_fns=LossMetrics)
gc.collect()

In [None]:
# Base learning rate. do_fit's default `lrs=slice(lr)` is evaluated from this value
# when do_fit is defined, so changing lr later does not change that default.
lr = 1e-3

### Funkcja ucząca

In [None]:
def do_fit(save_name, lrs=slice(lr), pct_start=0.9, n_epochs=10):
    """Train the global ``learn`` with the one-cycle policy, save it, and show a sample.

    Args:
        save_name: passed to ``learn.save`` (this notebook uses both plain
            names and ``Path`` objects).
        lrs: learning rate, or a ``slice`` of discriminative rates. The default
            ``slice(lr)`` is evaluated once, when this function is defined.
        pct_start: fraction of the cycle spent increasing the learning rate.
        n_epochs: number of epochs to train (previously hard-coded to 10).
    """
    learn.fit_one_cycle(n_epochs, lrs, pct_start=pct_start)
    learn.save(save_name)
    learn.show_results(rows=1, imgsize=5)

### Uczenie sieci

In [None]:
# Stage 1a: one-cycle fit at 10x the base learning rate.
do_fit(savePath/'1a', lrs=slice(lr*10))

In [None]:
# Stage 1b: unfreeze the whole network and fine-tune with discriminative learning rates.
learn.unfreeze()
do_fit(savePath/'1b', lrs=slice(1e-5, lr))

In [None]:
# Rebuild the first dataset's DataBunch, attach it to the learner and freeze it again.
learn.data = data = get_data1(bs=bs, size=size)
learn.freeze()
gc.collect()

In [None]:
# NOTE(review): no cell shown here saves a '2b' checkpoint (stage 1 saves to
# savePath/'1a' and savePath/'1b'). This cell therefore depends on weights
# from an earlier or other run. Confirm which checkpoint is intended so a fresh
# Restart & Run All is reproducible.
learn.load('2b')
print('ok')

In [None]:
# Stage 2ca: fit on the first dataset with the default learning-rate schedule.
do_fit(save_name='2ca')

In [None]:
# Stage 2cb: continue from 2ca on the second dataset (LR2 -> HR2).
learn.load('2ca')
learn.data = data = get_data2(bs=bs, size=size)
gc.collect()
do_fit(save_name='2cb')

In [None]:
# Stage 2da: back to the first dataset, unfreeze, and fine-tune with small
# learning rates and a shorter warm-up (pct_start=0.3).
learn.load('2cb')
learn.data = data = get_data1(bs=bs, size=size)
gc.collect()
learn.unfreeze()
do_fit(save_name='2da', lrs=slice(1e-6, 1e-4), pct_start=0.3)

In [None]:
# Stage 2db: same low-learning-rate fine-tuning, now on the second dataset.
learn.load('2da')
learn.data = data = get_data2(bs=bs, size=size)
gc.collect()

do_fit(save_name='2db', lrs=slice(1e-6, 1e-4), pct_start=0.3)

In [None]:
# Final weights: export the fastai learner, then save a plain PyTorch checkpoint
# holding both the pickled model object and its state_dict.
learn.load('2db')
learn.export('model.pkl')

model = learn.model
model.eval()

checkpoint = dict(model=model, state_dict=model.state_dict())
torch.save(checkpoint, 'checkpoint2.pth')