# AI Community @ Семинар № 8, весна 2018
## Fast Style transfer.

В этой тетради мы рассмотрим способ быстрого переноса стиля. Его основная идея в том, что мы сначала предобучаем модель на определённом стиле, а затем получаем стилизованное изображение всего за один проход по сети.

In [1]:
import argparse
import os
import sys
import time
import re

import numpy as np
import torch
import cv2
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
import torch.onnx

import fast_neural_style.neural_style.utils as utils
from fast_neural_style.neural_style.transformer_net import TransformerNet
from fast_neural_style.neural_style.vgg import Vgg16

In [2]:
# Pretrained style checkpoints shipped with the example, keyed by style name.
# Every checkpoint lives in the same directory and is named "<style>.pth".
available_models = {
    style_name: 'fast_neural_style/saved_models/' + style_name + '.pth'
    for style_name in ('candy', 'mosaic', 'rain_princess', 'udnie', 'sun')
}

In [3]:
# Run the network on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Flag meant for stylize_video(write_video=...); it only takes effect if the
# call site passes it along.
write_video = True
# Key into ``available_models`` selecting which pretrained style to load
# (the name is a typo for "style_type", kept because later cells use it).
stype_type = 'mosaic'
# Optional: a video file to stylize instead of the default camera.
# path_to_your_video = 'your_video.mp4'

In [4]:
style_model = TransformerNet()
# map_location places the weights directly on the selected device, so a
# checkpoint that contains CUDA tensors still loads on a CPU-only machine.
state_dict = torch.load(available_models[stype_type], map_location=device)
# Remove the deprecated InstanceNorm ``running_mean``/``running_var`` entries
# saved in the checkpoint before loading (presumably the current TransformerNet
# has no such buffers and ``load_state_dict`` is strict by default).
# Keys are deleted in place instead of rebuilding the dict, so any metadata
# attached to the loaded state dict object is preserved.
deprecated_norm_key = re.compile(r'in\d+\.running_(mean|var)$')
for key in list(state_dict.keys()):
    if deprecated_norm_key.search(key):
        del state_dict[key]
style_model.load_state_dict(state_dict)
style_model.to(device)
# The model is only used for inference in this notebook.
style_model.eval()


def _scale_to_pixel_range(tensor):
    """Multiply a [0, 1] image tensor back up to the 0-255 pixel range."""
    # ToTensor rescales uint8 pixels to [0, 1]; the rest of this notebook
    # treats model inputs/outputs as 0-255 values (outputs are clamped to it).
    return tensor.mul(255)


# Turn an HxWxC uint8 numpy image into a CxHxW float tensor with 0-255 values.
content_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(_scale_to_pixel_range),
])

In [5]:
def _stylize_bgr_frame(frame):
    """Stylize one OpenCV image with ``style_model`` and return it as BGR uint8.

    Assumes ``frame`` is a 3-channel BGR uint8 image, as produced by
    ``cv2.imread`` / ``VideoCapture.read``.
    """
    # OpenCV delivers BGR, but the model output is handled as RGB (it is
    # converted to BGR for display below), so feed the model RGB as well.
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    content_image = content_transform(rgb).unsqueeze(0).to(device)
    with torch.no_grad():
        output = style_model(content_image)[0].cpu()
    # CHW float in [0, 255] -> HWC uint8, then back to OpenCV's BGR order.
    stylized_rgb = output.clamp(0, 255).numpy().transpose(1, 2, 0).astype(np.uint8)
    return cv2.cvtColor(stylized_rgb, cv2.COLOR_RGB2BGR)


def _show_until_keypress(window_name, bgr_image):
    """Show ``bgr_image`` in an OpenCV window and block until a key is pressed."""
    cv2.imshow(window_name, bgr_image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


def stylize_video(path_to_video=None, write_video=False, output_path='output.mp4', fps=10):
    """Stylize a video stream frame by frame and preview it in an OpenCV window.

    Args:
        path_to_video: source passed to ``cv2.VideoCapture``; a falsy value
            opens the default camera (index 0).
        write_video: if True, also save the stylized frames to ``output_path``.
        output_path: destination file used when ``write_video`` is True.
        fps: frame rate recorded in the output file.

    Press ``q`` in the preview window to stop early; otherwise the loop ends
    when the source stops returning frames.
    """
    # Open the requested source (this used to read an undefined global,
    # ``path_to_your_video``, instead of the argument).
    cap = cv2.VideoCapture(path_to_video if path_to_video else 0)
    out = None
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            stylized = _stylize_bgr_frame(frame)

            if write_video:
                if out is None:
                    # Create the writer lazily from the *stylized* frame: the
                    # first frame is no longer consumed just to learn the size,
                    # and the writer's size always matches what it is given.
                    height, width = stylized.shape[:2]
                    # 'mp4v' is a tag the MP4 container accepts ('XVID' is
                    # usually paired with .avi).
                    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
                out.write(stylized)

            cv2.imshow('frame', stylized)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                print('Streaming has ended')
                break
    finally:
        # Always free the capture device / output file and close the window,
        # even if stylization raises.
        cap.release()
        if out is not None:
            out.release()
        cv2.destroyAllWindows()


def stylize_image(img_path):
    """Stylize the image stored at ``img_path`` and display it until a key press.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError('Could not read image: {}'.format(img_path))
    _show_until_keypress('stylized image', _stylize_bgr_frame(img))


def generate_texture(img_size=(512, 512)):
    """Stylize uniform random noise and display the resulting texture.

    Args:
        img_size: (height, width) of the noise image, i.e. the first two
            dimensions of the generated array.
    """
    # randint's upper bound is exclusive, so 256 is needed for 255 to occur.
    noise = np.random.randint(0, 256, tuple(img_size) + (3,), dtype=np.uint8)
    _show_until_keypress('stylized image', _stylize_bgr_frame(noise))

In [None]:
# Use the write_video flag from the settings cell (it was defined but never passed).
stylize_video(write_video=write_video)

In [12]:
# Demo: stylize the bundled cat photo with the loaded style.
stylize_image(img_path='fast_neural_style/images/content-images/cat.jpg')

In [None]:
# Demo: a 1024 x 2048 texture produced from stylized random noise.
generate_texture(img_size=(1024, 2048))