# Fast Style Transfer

## Includes

In [None]:
import os, sys
import ipdb
import torch as t
import torchvision as tv
import torchnet as tnt
from torch.nn import functional as fun
from tqdm import tqdm_notebook as tqdm

# Collect the working directory and every folder beneath it, then register
# each one on the module search path.
paths = [dirpath for dirpath, _, _ in os.walk(".")]

for item in paths:
    sys.path.append(item)

from ipynb.fs.full.config import Config
from ipynb.fs.full.module import BasicModule
from ipynb.fs.full.monitor import Visualizer
from ipynb.fs.full.network import LossNet, TransformNet
from ipynb.fs.full.util import *

## Initialization

In [None]:
# disable the automatic post-mortem debugger (use `%pdb on` to enable it)
%pdb off

# configuration object, visualization front-end and compute device
opt = Config()
vis = Visualizer()
if t.cuda.is_available():
    device = t.device("cuda")
else:
    device = t.device("cpu")

# input pipeline: resize and center-crop to opt.img_size, convert to a
# tensor, then multiply every value by 255
transforms = tv.transforms.Compose([
    tv.transforms.Resize(opt.img_size),
    tv.transforms.CenterCrop(opt.img_size),
    tv.transforms.ToTensor(),
    tv.transforms.Lambda(lambda pixels: pixels * 255),
])
dataset = tv.datasets.ImageFolder(opt.root_data, transform=transforms)
dataloader = t.utils.data.DataLoader(
    dataset,
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.num_workers,
)

# networks: the loss network is built with requires_grad=False, the
# transform network is the one handed to the optimizer below
vgg = LossNet(requires_grad=False).to(device)
model = TransformNet().to(device)

# restore a previously saved transform network when a path is configured
if opt.root_model:
    model.load(opt.root_model, device=device)

# Adam over the transform network's parameters
optimizer = t.optim.Adam(model.parameters(), lr=opt.lr)

# running averages for the style and content loss terms
meter_style = tnt.meter.AverageValueMeter()
meter_content = tnt.meter.AverageValueMeter()

# style reference: load the sample image, pass it through the loss network
# and calculate one Gram matrix per returned feature
img_style = loadStyleImg(opt.path_style).to(device)
features_style = vgg(img_style)
gram_style = list(map(getGramMatrix, features_style))

## Training entry

In [None]:
# Number of batches per epoch, used as the progress-bar total.
# len(dataloader) counts the final partial batch as well, whereas
# round(len(dataset) / batch_size) can come out one short (or be rounded the
# wrong way) when the dataset size is not a multiple of the batch size.
data_len = len(dataloader)

for epoch in range(opt.max_epoch):
    # start every epoch with fresh running averages
    meter_style.reset()
    meter_content.reset()

    # the dataset labels are not used, only the image batch
    for index, (img_batch, _) in tqdm(
            enumerate(dataloader), desc='epoch ' + str(epoch), total=data_len):

        # training
        optimizer.zero_grad()

        img_batch = img_batch.to(device)
        img_result = model(img_batch)

        # compute loss: both the input and the stylized output go through
        # normalizeBatch and the loss network
        features_in = vgg(normalizeBatch(img_batch))
        features_out = vgg(normalizeBatch(img_result))

        # content loss: weighted MSE between the relu3_3 features of the
        # input and of the output
        loss_content = opt.w_content * fun.mse_loss(features_in.relu3_3,
                                                    features_out.relu3_3)

        # style loss: weighted sum over the loss-network features of the MSE
        # between the output's Gram matrix and the style reference's Gram
        # matrix, the latter expanded to the shape of the former
        loss_style = 0.0
        for layer_out, gram_ref in zip(features_out, gram_style):
            gram_out = getGramMatrix(layer_out)
            loss_style += fun.mse_loss(gram_out, gram_ref.expand_as(gram_out))
        loss_style *= opt.w_style
        loss_total = loss_content + loss_style

        # backpropagation
        loss_total.backward()
        optimizer.step()

        # smooth loss for logging
        meter_content.add(loss_content.item())
        meter_style.add(loss_style.item())

        # visualize results every opt.freq_plot batches
        if (index + 1) % opt.freq_plot == 0:
            # plot the running mean of each loss term
            vis.plot('Content Loss', meter_content.value()[0])
            vis.plot('Style Loss', meter_style.value()[0])

            # show the first image of the batch and its stylized result
            vis.img('input', (img_batch[0]))
            vis.img('output', (img_result[0]))

    # save model once per epoch
    model.save()