In [None]:
# Inline plots; automatically reload edited modules before each cell runs.
%matplotlib inline
%reload_ext autoreload
%autoreload 2

## Style transfer net

In [None]:
# The star import is relied on for torch, np, os, plt, nn, F, optim and the fastai
# helpers used below; none of those names is imported explicitly in the cells shown.
from fastai.conv_learner import *
from pathlib import Path
torch.cuda.set_device(0)

import cv2

# Let cuDNN benchmark and pick the fastest convolution algorithms (pays off when
# input sizes do not change between batches).
torch.backends.cudnn.benchmark=True

In [None]:
PATH = Path('data')

In [None]:
fnames_full,label_arr_full,all_labels = folder_source(PATH, 'train/imagenet')
fnames_full = ['/'.join(Path(fn).parts[-2:]) for fn in fnames_full]
list(zip(fnames_full[:5],label_arr_full[:5]))

In [None]:
all_labels[:5]

In [None]:
np.random.seed(42)
keep_pct = 0.1
keeps = np.random.rand(len(fnames_full)) < keep_pct
fnames = np.array(fnames_full, copy=False)[keeps]
label_arr = np.array(label_arr_full, copy=False)[keeps]

## Test case: Lena images

In [None]:
# Lena test set; this replaces the ImageNet PATH defined above. The trailing '/'
# matters because PATH_TRN is built by plain string concatenation.
PATH = 'data/train/lena/'
PATH_TRN = f'{PATH}converted'
os.listdir(PATH)[:4]

In [None]:
# arch selects the model-specific transforms; MAX_SZ caps the side length used when
# resizing images and is the transform size; bs=1 image per batch.
# The commented lines are earlier size/batch settings.
arch = vgg16
# sz,bs = 96,32
MAX_SZ,bs = 512,1
# sz,bs = 128,32

In [None]:
def scale_match(h, w, targ):
    """Resize `targ` so it covers an h x w frame, then return its top-left h x w crop.

    The image is scaled by the larger of the two ratios (aspect ratio preserved),
    so both sides end up at least as large as the target before cropping.

    Args:
        h, w: target height and width in pixels.
        targ: image array indexed (height, width, channels); the shape unpack
            below requires exactly three dimensions.
    Returns:
        The top-left h x w region of the resized image.
    """
    sh, sw, _ = targ.shape
    rat = max(h / sh, w / sw)
    # int() alone can truncate a product such as 49 * (1/49) = 0.999... and leave
    # the resized image one pixel short of the target; clamp to the target size.
    new_h = max(h, int(sh * rat))
    new_w = max(w, int(sw * rat))
    res = cv2.resize(targ, (new_w, new_h))  # cv2 dsize is (width, height)
    return res[:h, :w]

In [None]:
def get_shape_mean(path):
    """Return the mean (height, width) of the images directly inside `path`.

    Only files whose name ends in 'jpeg' are considered.

    Args:
        path: directory to scan; each file name is joined onto it with
            os.path.join, so a trailing separator is optional.
    Returns:
        (mean_height, mean_width) as floats.
    Raises:
        ValueError: if no matching file is found (instead of dividing by zero).
    """
    sum_h = 0
    sum_w = 0
    count = 0
    for fname in os.listdir(path):
        if fname.endswith('jpeg'):
            img = open_image(os.path.join(path, fname))
            sum_h += img.shape[0]
            sum_w += img.shape[1]
            count += 1
    if count == 0:
        raise ValueError(f'no *jpeg images found in {path!r}')
    return sum_h / count, sum_w / count

In [None]:
def set_shape_mean(path):
    """Write square, resized copies of the '*jpeg' images in `path` to `path`/converted.

    The square side is the smaller of the mean height and mean width reported by
    get_shape_mean, capped at MAX_SZ. Each image is scaled to cover that square
    and cropped (scale_match), then saved as 'converted/converted_<fname>'.

    Raises:
        ValueError: propagated from get_shape_mean when `path` holds no jpeg files.
        IOError: if cv2 fails to write an output file.
    """
    mean_h, mean_w = get_shape_mean(path)
    size = min(int(mean_h), int(mean_w), MAX_SZ)
    out_dir = os.path.join(path, 'converted')
    # cv2.imwrite neither creates directories nor raises (it only returns False),
    # so ensure the target exists and surface failed writes.
    os.makedirs(out_dir, exist_ok=True)
    for fname in os.listdir(path):
        if not fname.endswith('jpeg'):
            continue
        img = open_image(os.path.join(path, fname))
        square = scale_match(size, size, img)
        # Assumes open_image returns RGB floats in [0, 1] (hence *255 and the
        # channel swap to cv2's BGR order) — confirm.
        out_fn = os.path.join(out_dir, f'converted_{fname}')
        if not cv2.imwrite(out_fn, cv2.cvtColor(square * 255, cv2.COLOR_RGB2BGR)):
            raise IOError(f'failed to write {out_fn}')

In [None]:
set_shape_mean(PATH)

In [None]:
fnames = np.array(os.listdir('data/train/lena/converted'))
fnames[:5]

In [None]:
label_arr = np.array(list(range(0, len(fnames))))
label_arr[:5]

In [None]:
keeps = np.array(np.full((len(label_arr)), 1))
keeps

In [None]:
class MatchedFilesDataset(FilesDataset):
    """Dataset whose inputs and targets are both images read from disk.

    `fnames` (the inputs) are passed through to FilesDataset; `y` holds the
    matching target file names, which are opened relative to `self.path`.
    """
    def __init__(self, fnames, y, transform, path):
        self.y = y
        assert len(fnames) == len(y)
        super().__init__(fnames, transform, path)

    def get_y(self, i):
        target_fn = os.path.join(self.path, self.y[i])
        return open_image(target_fn)

    def get_c(self):
        # Targets are images rather than class labels, so report zero classes.
        return 0

In [None]:
# Hold out 20% of the files for validation. Inputs and targets are the same file
# names: CombinedLoss compares the network's output with the original image
# (content term) and with the style image (style term).
# The trailing comment is presumably the earlier val_pct for the ImageNet subset.
val_idxs = get_cv_idxs(len(fnames), val_pct=0.2) #min(0.01/keep_pct, 0.1)
((val_x,trn_x),(val_y,trn_y)) = split_by_idx(val_idxs, np.array(fnames), np.array(fnames))
len(val_x),len(trn_x)

In [None]:
# Model-specific transforms at MAX_SZ; tfm_y=TfmType.PIXEL presumably makes fastai
# apply the same transforms to the target image — confirm.
tfms = tfms_from_model(arch, MAX_SZ, tfm_y=TfmType.PIXEL)
datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms, path=PATH_TRN)
md = ImageData(PATH, datasets, bs, num_workers=8, classes=None)

In [None]:
# Undoes the input normalisation for display; used by show_img and the result cells.
denorm = md.val_ds.denorm

In [None]:
def show_img(ims, idx, figsize=(5,5), normed=True, ax=None):
    """Display image `idx` of a batch on a matplotlib axis, with axes hidden.

    When `normed` is true the batch goes through the dataset's `denorm`;
    otherwise it is converted to numpy and axis 1 is rolled to the end
    (channels-last). Pixel values are clipped to [0, 1] before display.
    A new figure of size `figsize` is created when `ax` is not given.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    imgs = denorm(ims) if normed else np.rollaxis(to_np(ims), 1, 4)
    ax.imshow(np.clip(imgs, 0, 1)[idx])
    ax.axis('off')

## Model

In [None]:
def conv(ni, nf, kernel_size=3, stride=1, actn=True, pad=None, bn=True):
    """Build a Conv2d -> [ReLU] -> [BatchNorm2d] unit.

    Args:
        ni, nf: input / output channel counts.
        kernel_size, stride: forwarded to nn.Conv2d.
        actn: append an in-place ReLU after the convolution.
        pad: padding; defaults to kernel_size // 2 (size-preserving for odd
            kernels at stride 1).
        bn: append BatchNorm2d; the convolution then has no bias because BN
            provides its own shift.
    Returns:
        nn.Sequential containing the selected layers in that order.
    """
    padding = kernel_size // 2 if pad is None else pad
    conv_layer = nn.Conv2d(ni, nf, kernel_size, stride=stride, padding=padding, bias=not bn)
    extras = ([nn.ReLU(inplace=True)] if actn else []) + ([nn.BatchNorm2d(nf)] if bn else [])
    return nn.Sequential(conv_layer, *extras)

In [None]:
class ResSequentialCenter(nn.Module):
    """Residual wrapper for a branch that shrinks the spatial size.

    The wrapped layers run as one nn.Sequential (stored as `m`). The skip path
    is center-cropped by 2 pixels on every spatial border before the sum, which
    matches a branch of two unpadded 3x3 convolutions as built by res_block.
    """
    def __init__(self, layers):
        super().__init__()
        self.m = nn.Sequential(*layers)

    def forward(self, x):
        trimmed = x[:, :, 2:-2, 2:-2]
        return trimmed + self.m(x)

In [None]:
def res_block(nf):
    """Residual block of two unpadded 3x3 conv units at `nf` channels.

    Each unpadded conv trims one pixel per border, which is why
    ResSequentialCenter crops the skip path by 2.
    NOTE(review): both units apply ReLU (conv's default is actn=True), whereas
    the Johnson et al. residual block has no activation after the second conv —
    confirm this is intended.
    """
    first = conv(nf, nf, actn=True, pad=0)
    second = conv(nf, nf, pad=0)
    return ResSequentialCenter([first, second])

In [None]:
def upsample(ni, nf):
    """Double the spatial size with nn.Upsample (default mode), then apply a conv unit ni -> nf."""
    layers = [nn.Upsample(scale_factor=2), conv(ni, nf)]
    return nn.Sequential(*layers)

In [None]:
class StyleResnet(nn.Module):
    """Feed-forward style-transfer network: downsample, residual body, upsample.

    Layout: reflection pad 40 -> 9x9 conv (3->32) -> two stride-2 convs
    (32->64->128) -> five residual blocks at 128 channels -> two 2x upsamples
    (128->64->32) -> 9x9 conv to 3 channels without ReLU.
    The 40-pixel pad compensates for the residual blocks' cropping: each trims
    2 px per border at 1/4 resolution, i.e. 5 * 2 * 4 = 40 px at full
    resolution, so the output keeps the input's spatial size (for sides
    divisible by 4).
    """
    def __init__(self):
        super().__init__()
        layers = [nn.ReflectionPad2d(40), conv(3, 32, 9),
                  conv(32, 64, stride=2), conv(64, 128, stride=2)]
        layers.extend(res_block(128) for _ in range(5))
        layers.append(upsample(128, 64))
        layers.append(upsample(64, 32))
        layers.append(conv(32, 3, 9, actn=False))
        self.features = nn.Sequential(*layers)

    def forward(self, x):
        return self.features(x)

## Style image

In [None]:
# Style reference image; after resizing and transforming below it becomes the style
# target passed to CombinedLoss.
style_fn = 'data/style/impressionism/original/23. signac-saint-tropez-fontaine-des-lices-1895-copy.jpeg'
style_img = open_image(style_fn)
style_img.shape

In [None]:
plt.imshow(style_img);

In [None]:
# Scale the style image so BOTH sides are at least MAX_SZ (the ratio must consider
# height and width, otherwise portrait images stay narrower than MAX_SZ), then keep
# the top-right MAX_SZ x MAX_SZ square. Clamping after int() guards against float
# rounding leaving a side one pixel short.
h,w,_ = style_img.shape
rat = max(MAX_SZ/h, MAX_SZ/w)
new_h, new_w = max(MAX_SZ, int(h*rat)), max(MAX_SZ, int(w*rat))
res = cv2.resize(style_img, (new_w, new_h), interpolation=cv2.INTER_AREA)  # dsize is (width, height)
resz_style = res[:MAX_SZ,-MAX_SZ:]

In [None]:
plt.imshow(resz_style);

In [None]:
# Apply tfms[1] (presumably the validation-time transform of fastai's (train, val)
# pair — confirm) to the style image, passed as both x and y.
style_tfm,_ = tfms[1](resz_style,resz_style)

In [None]:
# Repeat the style image along a new leading batch axis up to batch size `bs`
# (np.broadcast_to returns a read-only view, no copy).
style_tfm = np.broadcast_to(style_tfm[None], (bs,)+style_tfm.shape)

In [None]:
style_tfm.shape

## Perceptual loss

In [None]:
# VGG16 used as the fixed feature extractor for the perceptual loss
# (the True argument presumably requests pretrained weights — confirm).
m_vgg = vgg16(True)

In [None]:
# Index of the layer just before each MaxPool2d, i.e. the last activation at each
# resolution; blocks[1:] are the layers hooked by CombinedLoss.
blocks = [i-1 for i,o in enumerate(children(m_vgg))
              if isinstance(o,nn.MaxPool2d)]
blocks, [m_vgg[i] for i in blocks[1:]]

In [None]:
# Keep only the first 43 layers (NOTE(review): presumably drops the final max-pool;
# confirm against the `blocks` output). Indices in `blocks` remain valid as long as
# they are < 43. The truncated model is moved to the GPU, put in eval mode and
# marked non-trainable.
vgg_layers = children(m_vgg)[:43]
m_vgg = nn.Sequential(*vgg_layers).cuda().eval()
set_trainable(m_vgg, False)

In [None]:
def flatten(x):
    """Collapse every dimension after the first: (N, d1, d2, ...) -> (N, d1*d2*...)."""
    batch = x.size(0)
    return x.view(batch, -1)

In [None]:
class SaveFeatures:
    """Capture the most recent output of a module through a forward hook.

    `features` is None until the hooked module has run; afterwards it holds the
    module's latest output. Call `remove()` to detach the hook.
    """
    features = None

    def __init__(self, m):
        handle = m.register_forward_hook(self.hook_fn)
        self.hook = handle

    def hook_fn(self, module, input, output):
        # Invoked by the hooked module after each of its forward passes.
        self.features = output

    def remove(self):
        handle = self.hook
        handle.remove()

In [None]:
def ct_loss(input, target):
    """Content loss: mean squared error between two feature maps."""
    return F.mse_loss(input, target)

def gram(input):
    """Scaled Gram matrix of a (b, c, h, w) feature map.

    Returns a (b, c, c) tensor of channel-by-channel inner products divided by
    c*h*w and multiplied by 1e6 (presumably to keep the style loss in a
    convenient numeric range).
    """
    b, c, h, w = input.size()
    flat = input.view(b, c, h * w)
    return torch.bmm(flat, flat.transpose(1, 2)) / (c * h * w) * 1e6

def gram_loss(input, target):
    """Style loss: MSE between the Gram matrices of `input` and `target`.

    `target` is trimmed to input's batch size, so a full-size style batch can be
    compared with a smaller final batch.
    """
    n = input.size(0)
    return F.mse_loss(gram(input), gram(target[:n]))

In [None]:
class CombinedLoss(nn.Module):
    """Perceptual loss: a content (feature MSE) term plus Gram-matrix style terms.

    Forward hooks capture the outputs of ``m[i]`` for every ``i`` in
    ``layer_ids``. The style image's hooked activations are computed once at
    construction and stored as fixed style targets.

    Args:
        m: feature-extractor model, indexable by the entries of ``layer_ids``.
        layer_ids: indices of the layers in ``m`` whose outputs are hooked.
        style_im: style image batch; wrapped with ``VV`` and run through ``m``.
        ct_wgt: multiplier for the content loss.
        style_wgts: one multiplier per entry of ``layer_ids`` for the style losses.
        ct_idx: position in ``layer_ids`` of the layer used for the content loss
            (default 2, i.e. the third hooked layer).

    Raises:
        ValueError: if ``style_wgts`` and ``layer_ids`` differ in length (``zip``
            would otherwise silently drop style layers) or ``ct_idx`` is out of range.
    """
    def __init__(self, m, layer_ids, style_im, ct_wgt, style_wgts, ct_idx=2):
        super().__init__()
        # Materialise both so they can be measured and iterated on every forward.
        layer_ids, style_wgts = list(layer_ids), list(style_wgts)
        if len(style_wgts) != len(layer_ids):
            raise ValueError(f'expected {len(layer_ids)} style weights, got {len(style_wgts)}')
        if not -len(layer_ids) <= ct_idx < len(layer_ids):
            raise ValueError(f'ct_idx {ct_idx} out of range for {len(layer_ids)} hooked layers')
        self.m,self.ct_wgt,self.style_wgts,self.ct_idx = m,ct_wgt,style_wgts,ct_idx
        self.sfs = [SaveFeatures(m[i]) for i in layer_ids]
        # One pass over the style image fills every hook; clone the activations so
        # later forward passes (which overwrite the hooks) leave them untouched.
        m(VV(style_im))
        self.style_feat = [V(o.features.data.clone()) for o in self.sfs]

    def forward(self, input, target, sum_layers=True):
        """Compute the loss of ``input`` (network output) against ``target``.

        Returns the weighted sum when ``sum_layers`` is true, otherwise the list
        ``[content_loss, style_loss_0, style_loss_1, ...]``.
        """
        # Run the target first and snapshot its content-layer activation: the pass
        # over `input` below overwrites every hook's `features`.
        self.m(VV(target.data))
        targ_feat = self.sfs[self.ct_idx].features.data.clone()
        self.m(input)
        inp_feat = [o.features for o in self.sfs]

        res = [ct_loss(inp_feat[self.ct_idx], V(targ_feat)) * self.ct_wgt]
        res += [gram_loss(inp, targ) * wgt for inp, targ, wgt
                in zip(inp_feat, self.style_feat, self.style_wgts)]

        if sum_layers: res = sum(res)
        return res

    def close(self):
        """Remove the forward hooks registered on ``m``."""
        for o in self.sfs: o.remove()

In [None]:
# Transformation network being trained, moved to the GPU.
m = StyleResnet()
m = to_gpu(m)

In [None]:
learn = Learner(md, SingleModel(m), opt_fn=optim.Adam)

In [None]:
# Perceptual loss: content weight 1e5, plus one style weight per hooked VGG layer
# in blocks[1:] (four weights are given here).
learn.crit = CombinedLoss(m_vgg, blocks[1:], style_tfm, 1e5, [0.025,0.275,5.,0.2])

In [None]:
# Weight decay; passed to the later fit calls (the first fit leaves it commented out).
wd=1e-7

In [None]:
# Learning-rate finder; inspect the plot to choose `lr` below.
learn.lr_find() #wds=wd
learn.sched.plot() #n_skip_end=1

In [None]:
lr=1e-3

In [None]:
# NOTE(review): unlike the runs below, this first run leaves wds/use_clr commented out.
learn.fit(lr, 5, cycle_len=1) #, wds=wd, use_clr=(20,10)

In [None]:
# NOTE(review): checkpoint names presumably encode the training schedule so far; the
# "20e" here does not obviously match the 5 x 1-epoch run above — confirm.
learn.save('impressionism_20e')

In [None]:
learn.fit(lr, 5, cycle_len=2, wds=wd, use_clr=(20,10))

In [None]:
learn.save('impressionism_20e10ER')

In [None]:
learn.fit(lr, 3, cycle_len=1, cycle_mult=2, wds=wd, use_clr=(20,10))

In [None]:
learn.save('impressionism_20e10er7erm')

In [None]:
# NOTE(review): no cell shown after this long run saves its weights.
learn.fit(lr, 50, cycle_len=1, cycle_mult=1, wds=wd, use_clr=(20,10))

## Inspect a prediction and save the stylized image

In [None]:
# Pick one validation sample; input and target come from the same file here.
# NOTE(review): len(val_x)-12 is negative if the validation set has fewer than 12 images.
x,y=md.val_ds[len(val_x)-12]

In [None]:
# Stylize the single sample (x[None] adds a leading batch axis).
learn.model.eval()
preds = learn.model(VV(x[None]))
x.shape,y.shape,preds.shape

In [None]:
# Unsummed loss terms: [content, style for each hooked layer].
learn.crit(preds, VV(y[None]), sum_layers=False)

In [None]:
# NOTE(review): commented-out duplicate of the cell above — candidate for removal.
#learn.crit(preds, VV(y[None]), sum_layers=False)

In [None]:
# Would remove the loss's forward hooks; left commented out so learn.crit stays usable.
#learn.crit.close()

In [None]:
# Input image (left) next to the stylized prediction (right).
_,axes=plt.subplots(1,2,figsize=(14,7))
show_img(x[None], 0, ax=axes[0])
show_img(preds, 0, ax=axes[1])

In [None]:
im = np.clip(denorm(preds),0,1)[0]

In [None]:
cv2.imwrite(f'{PATH}/result.jpeg', im*255)