<a href="https://colab.research.google.com/github/ptran1203/style_transfer/blob/master/Adain_Style.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Utils

In [1]:
import pickle
import numpy as np
import urllib.request
import keras.preprocessing.image as image_processing
import cv2
try:
    from google.colab.patches import cv2_imshow
except ImportError:
    from cv2 import imshow as cv2_imshow

MEAN_PIXCELS = np.array([103.939, 116.779, 123.68])

def pickle_save(object, path, log=False):
    try:
        log and print('save data to {} successfully'.format(path))
        with open(path, "wb") as f:
            return pickle.dump(object, f)
    except:
        log and print('save data to {} failed'.format(path))


def pickle_load(path, log=False):
    try:
        log and print("Loading data from {} - ".format(path))
        with open(path, "rb") as f:
            data = pickle.load(f)
            log and print("DONE")
            return data
    except Exception as e:
        print(str(e))
        return None

def norm(imgs):
    return (imgs - 127.5) / 127.5


def de_norm(imgs):
    return imgs * 127.5 + 127.5


def preprocess(imgs):
    """
    BGR -> RBG then subtract the mean
    """
    return imgs - MEAN_PIXCELS
    return imgs[...,[2,1,0]] - MEAN_PIXCELS


def deprocess(imgs):
    return imgs + MEAN_PIXCELS
    return (imgs + MEAN_PIXCELS)[...,[2,1,0]]


def show_images(img_array, denorm=True, deprcs=True):
    shape = img_array.shape
    img_array = img_array.reshape(
        (-1, shape[-4], shape[-3], shape[-2], shape[-1])
    )
    # convert 1 channel to 3 channels
    channels = img_array.shape[-1]
    resolution = img_array.shape[2]
    img_rows = img_array.shape[0]
    img_cols = img_array.shape[1]

    img = np.full([resolution * img_rows, resolution * img_cols, channels], 0.0)
    for r in range(img_rows):
        for c in range(img_cols):
            img[
            (resolution * r): (resolution * (r + 1)),
            (resolution * (c % 10)): (resolution * ((c % 10) + 1)),
            :] = img_array[r, c]

    if denorm:
        img = de_norm(img)
    if deprcs:
        img = deprocess(img)

    cv2_imshow(img)


def http_get_img(url, rst=64, gray=False, normalize=True):
    req = urllib.request.urlopen(url)
    arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
    img = cv2.imdecode(arr, -1)
    if rst is not None:
        img = image_resize(img, rst)
    if gray:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    
    img = np.expand_dims(img, 0)
    if normalize:
        img = norm(preprocess(img))

    return img


def image_resize(image, width = None, height = None, inter = cv2.INTER_AREA):
    dim = None
    (h, w) = image.shape[:2]
    if width is None and height is None:
        return image

    if width is None:
        r = height / float(h)
        dim = (int(w * r), height)

    else:
        r = width / float(w)
        dim = (width, int(h * r))
    resized = cv2.resize(image, dim, interpolation = inter)
    return resized


dataloader

In [2]:
import numpy as np
#import utils
from collections import Counter
import os
from sklearn.model_selection import train_test_split
try:
    from google.colab.patches import cv2_imshow
except ImportError:
    from cv2 import imshow as cv2_imshow

class DataGenerator:
    def __init__(self, base_dir, batch_size, rst, max_size=500,
    multi_batch=False, normalize=True, preprocessing=True):
        BATCH_FILES = 4
        self.base_dir = base_dir
        self.batch_size = batch_size
        self.id = 1
        self.rst = rst
        self.multi_batch = multi_batch
        self.normalize = normalize
        self.max_size = max_size
        self.preprocessing = preprocessing
        self.x = self.get_content_images()

        if multi_batch:
            self.y = self.get_style_images(self.id)
        else:
            self.y = self.get_style_images()

        self.max_size = max_size

        if self.preprocessing:
            self.x = preprocess(self.x)
            self.y = preprocess(self.y)

        if normalize:
            self.x = norm(self.x)
            self.y = norm(self.y)


    def get_content_images(self):
        return pickle_load(
            os.path.join(self.base_dir, 'dataset/content_imgs_{}.pkl'.format(self.rst)))[:self.max_size]


    def get_style_images(self, _id=""):
        fname = 'style_imgs_{}'.format(self.rst)

        if _id:
            fname += "_" + str(_id)

        return pickle_load(
                os.path.join(self.base_dir, 'dataset/{}.pkl'.format(fname)))[:self.max_size]


    def next_id(self):
        self.id += 1
        if self.id > self.BATCH_FILES:
            self.id = 1
        
        self.y = self.get_style_images(self.id)[:self.max_size]

        if self.preprocessing:
            self.y = preprocess(self.y)
        if self.normalize:
            self.y = norm(self.y)


    def augment_one(self, x, y):
        seed = np.random.randint(0, 100)
        new_x = transform(x, seed)
        new_y = transform(y, seed)
        return new_x, new_y


    def augment_array(self, x, y, augment_factor):
        imgs = []
        masks = []
        for i in range(len(x)):
            imgs.append(x[i])
            masks.append(y[i])
            for _ in range(augment_factor):
                _x, _y = self.augment_one(x[i], y[i])
                imgs.append(_x)
                masks.append(_y)

        return np.array(imgs), np.array(masks)


    def shuffle_style_imgs(self):
        size = len(self.y)
        indices = np.arange(size)
        np.random.shuffle(indices)
        return self.y[indices]


    def next_batch(self, augment_factor):
        if self.multi_batch:
            x = self.x
            indices = np.arange(x.shape[0])
            np.random.shuffle(indices)
            max_id = x.shape[0] - self.batch_size + 1
            print("[", end="")
            for i in range(self.BATCH_FILES):
                for start_idx in range(0, max_id, self.batch_size):
                    access_pattern = indices[start_idx:start_idx + self.batch_size]

                    yield (
                        x[access_pattern, :, :, :],
                        self.y[access_pattern],
                    )
                print("{}/6 - ".format(i+1), end="")
                self.next_id()
            print("]")
        else:
            x = self.x
            self.y = self.shuffle_style_imgs()

            indices = np.arange(x.shape[0])
            np.random.shuffle(indices)
            max_id = x.shape[0] - self.batch_size + 1
            for start_idx in range(0, max_id, self.batch_size):
                access_pattern = indices[start_idx:start_idx + self.batch_size]

                yield (
                    x[access_pattern, :, :, :],
                    self.y[access_pattern],
                )

    def get_random_sample(self, test=True):
        if test:
            idx = np.random.randint(0, self.x_test.shape - 1)
            return self.x_test[idx], self.y_test[idx]

        idx = np.random.randint(0, self.x.shape - 1)
        return self.x[idx], self.y[idx]


    def random_show(self, option='style'):
        """
        option: ['style', 'content']
        """
        idx = np.random.randint(0, self.x.shape - 1)
        if option == 'style':
            return cv2_imshow(de_norm(self.y[idx]))

        return cv2_imshow(de_norm(self.x[idx]))


    def show_imgs(self, img):
        if len(img.shape) == 4:
            return show_images(img, self.normalize, self.preprocessing)

        if self.normalize:
            img = de_norm(img)
        if self.preprocessing:
            img = deprocess(img)

        cv2_imshow(img)


model

In [3]:
import tensorflow as tf
import keras
import numpy as np
import datetime
import matplotlib.pyplot as plt
#import utils
import keras.backend as K

from keras.layers.convolutional import Conv2D
from keras.layers import Input, Activation, Layer, UpSampling2D
from keras.models import Model
from keras.optimizers import Adam
from keras.applications.vgg19 import VGG19
from keras.applications.vgg16 import VGG16

try:
    # In case run on google colab
    from google.colab.patches import cv2_imshow
except ImportError:
    from cv2 import imshow as cv2_imshow

DEFAULT_STYLE_LAYERS = [
    'block1_conv1', 'block2_conv1',
    'block3_conv1', 'block4_conv1',
]
DEFAULT_LAST_LAYER = 'block4_conv1'


class AdaptiveInstanceNorm(Layer):
    def __init__(self, epsilon=1e-3):
        super(AdaptiveInstanceNorm, self).__init__()
        self.epsilon = epsilon


    def call(self, inputs):
        x, style = inputs
        axis = [1, 2]
        x_mean = K.mean(x, axis=axis, keepdims=True)
        x_std = K.std(x, axis=axis, keepdims=True)

        style_mean = K.mean(style, axis=axis, keepdims=True)
        style_std = K.std(style, axis=axis, keepdims=True)

        norm = (x - x_mean) * (1 / (x_std + self.epsilon))

        return norm * (style_std + self.epsilon) + style_mean


    def compute_output_shape(self, input_shape):
        return input_shape[0]


class Reduction(Layer):
    def __init__(self):
        super(Reduction, self).__init__()

    def call(self, inputs):
        return tf.reduce_sum(inputs)

class StyleTransferModel:
    def __init__(self, base_dir, rst, lr,
                style_layer_names=DEFAULT_STYLE_LAYERS,
                last_layer=DEFAULT_LAST_LAYER,
                show_interval=25,
                style_loss_weight=1,
                pre_trained_model='vgg16'):
        self.base_dir = base_dir
        self.rst = rst
        self.pre_trained_model = pre_trained_model
        self.lr = lr
        self.style_layer_names = style_layer_names
        self.last_layer = last_layer
        self.show_interval = show_interval
        img_shape = (self.rst, self.rst, 3)

        # ===== Build the model ===== #
        self.encoder = self.build_encoder()
        self.style_layers = self.build_style_layers()
        content_img = Input(shape=img_shape)
        style_img = Input(shape=img_shape)

        content_feat = self.encoder(content_img)
        style_feat = self.encoder(style_img)

        combined_feat = AdaptiveInstanceNorm()([content_feat, style_feat])
        self.init_rst = K.int_shape(combined_feat)[1]
        self.decoder = self.build_decoder((self.init_rst, self.init_rst, 512))

        gen_img = self.decoder(combined_feat)
        gen_feat = self.encoder(gen_img)

        self.transfer_model = Model(inputs=[content_img, style_img],
                                    outputs=gen_img)
        content_loss = K.mean(K.square(combined_feat - gen_feat), axis=[1, 2])
        self.transfer_model.add_loss(Reduction()(content_loss))
        self.transfer_model.add_loss(style_loss_weight*self.compute_style_loss(gen_img, style_img))
        self.transfer_model.compile(optimizer=Adam(self.lr),
                                    loss=["mse"],
                                    loss_weights=[0.0])


    def compute_style_loss(self, gen_img, style_img):
        gen_feats = self.style_layers(gen_img)
        style_feats = self.style_layers(style_img)
        style_loss = []
        axis = [1, 2]
        for i in range(len(style_feats)):
            gmean = K.mean(gen_feats[i], axis=axis)
            gstd = K.std(gen_feats[i], axis=axis)

            smean = K.mean(style_feats[i], axis=axis)
            sstd = K.std(style_feats[i], axis=axis)

            style_loss.append(
                K.sum(K.square(gmean - smean)) +
                K.sum(K.square(gstd - sstd))
            )

        return Reduction()(style_loss)


    def build_style_layers(self):
        return Model(
            inputs=self.encoder.inputs,
            outputs=[self.encoder.get_layer(l).get_output_at(0) \
                for l in self.style_layer_names]
        )


    def build_encoder(self):
        input_shape = (self.rst, self.rst, 3)
        vggnet = VGG16 if self.pre_trained_model == 'vgg16' else VGG19
        model = vggnet(
            include_top=False,
            weights='imagenet',
            input_tensor=Input(input_shape),
            input_shape=input_shape,
        )
        print('Encoder: {}'.format(model.name))
        model.trainable = False
        for layer in model.layers:
            layer.trainable = False

        return Model(
            inputs=model.inputs,
            outputs=model.get_layer(self.last_layer).get_output_at(0)
        )


    def conv_block(self, x, filters, kernel_size,
                    activation='relu', up_sampling=False):

        x = Conv2D(filters, kernel_size=kernel_size, strides=1,
                    padding='same', activation=activation)(x)

        if up_sampling:
            x = UpSampling2D(size=(2, 2), interpolation='nearest')(x)

        return x


    def build_decoder(self, input_shape):
        feat = Input(input_shape)
        kernel_size = 3

        x = self.conv_block(feat, 512, kernel_size=kernel_size, up_sampling=True)

        x = self.conv_block(x, 256, kernel_size=kernel_size)
        x = self.conv_block(x, 256, kernel_size=kernel_size)
        x = self.conv_block(x, 256, kernel_size=kernel_size)
        x = self.conv_block(x, 256, kernel_size=kernel_size, up_sampling=True)

        # x = self.conv_block(x, 128, kernel_size=kernel_size)
        # x = self.conv_block(x, 128, kernel_size=kernel_size)
        x = self.conv_block(x, 128, kernel_size=kernel_size)
        x = self.conv_block(x, 128, kernel_size=kernel_size, up_sampling=True)

        x = self.conv_block(x, 64, kernel_size=kernel_size)
        x = self.conv_block(x, 64, kernel_size=kernel_size)

        style_image = self.conv_block(x, 3, kernel_size=kernel_size, activation='linear')

        model = Model(inputs=feat, outputs=style_image, name='decoder')
        return model


    @staticmethod
    def init_hist():
        return {
            "loss": [],
            "val_loss": []
        }


    def train(self, data_gen, epochs, augment_factor=0):
        history = self.init_hist()
        print("Train on {} samples".format(len(data_gen.x)))

        for e in range(epochs):
            start_time = datetime.datetime.now()
            print("Train epochs {}/{} - ".format(e + 1, epochs), end="")

            batch_loss = self.init_hist()
            for content_img, style_img in data_gen.next_batch(augment_factor):
                loss = self.transfer_model.train_on_batch([content_img, style_img],
                                                          style_img)
                batch_loss['loss'].append(loss)

            # evaluate
            # batch_loss['val_loss'] = 

            mean_loss = np.mean(np.array(batch_loss['loss']))
            mean_val_loss = 0#np.mean(np.array(batch_loss['val_loss']))

            history['loss'].append(mean_loss)
            history['val_loss'].append(mean_val_loss)

            print("Loss: {}, Val Loss: {} - {}".format(
                mean_loss, mean_val_loss,
                datetime.datetime.now() - start_time
            ))

            if e % self.show_interval == 0:
                self.save_weight()
                idx = np.random.randint(0, data_gen.max_size - 1)
                cimg, simg = data_gen.x[idx:idx+1], data_gen.y[idx:idx+1]
                gen_img = self.generate(cimg, simg)
                data_gen.show_imgs(np.concatenate([cimg, simg, gen_img]))

        self.history = history
        return history


    def plot_history(self):
        plt.plot(self.history['loss'], label='train loss')
        plt.plot(self.history['val_loss'], label='val loss')
        plt.ylabel('loss')
        plt.xlabel('epoch')
        plt.title('Segmentation model')
        plt.legend()
        plt.show()


    def save_weight(self):
        try:
            self.transfer_model.save_weights(self.base_dir + '/transfer_model.h5')
        except Exception as e:
            print("Could not load model, {}".format(str(e))) 


    def load_weight(self):
        try:
            self.transfer_model.load_weights(self.base_dir + '/transfer_model.h5')
        except Exception as e:
            print("Save model failed, {}".format(str(e))) 


    def generate(self, content_imgs, style_imgs):
        return self.transfer_model.predict([content_imgs, style_imgs])


    def show_sample(self, content_img, style_img,
                    concate=True, denorm=True, deprocess=True):
        gen_img = self.generate(content_img, style_img)

        if concate:
            return show_images(np.concatenate([content_img, style_img, gen_img]), denorm, deprocess)

        if denorm:
            content_img = de_norm(content_img)
            style_img = de_norm(style_img)
            gen_img = de_norm(gen_img)
        if deprocess:
            content_img = deprocess(content_img)
            style_img = deprocess(style_img)
            gen_img = deprocess(gen_img)

        cv2_imshow(content_img[0])
        cv2_imshow(style_img[0])
        cv2_imshow(gen_img[0])


In [None]:
%tensorflow_version 1.x
from google.colab import drive, output
data_loaded = False
drive.mount('/content/drive')
BASE_DIR = "/content/drive/My Drive/Style_Transfer"
!rm -rf '/content/style_transfer'
!git clone https://github.com/ptran1203/style_transfer

In [4]:
BASE_DIR = "Style_Transfer"

In [None]:
cd style_transfer

[Errno 2] No such file or directory: 'style_transfer'
/content


In [None]:
import cv2
from google.colab.patches import cv2_imshow
import numpy as np
import pickle
import os
#import utils
#from dataloader import DataGenerator
#from model import *

class DataGen(DataGenerator):
    BATCH_FILES= 4

class SModel(StyleTransferModel):
    pass


style_layer_names=[
    'block1_conv1', 'block2_conv1',
    'block3_conv1', 'block4_conv1',
]
last_layer='block4_conv1'
pre_trained_model = 'vgg19'
rst = 256
data_gen = DataGen(BASE_DIR, 8, rst=rst, max_size=1500, multi_batch=False,
                   normalize=True)
smodel = SModel(BASE_DIR, None, 1e-4,
                style_layer_names=style_layer_names,
                last_layer=last_layer, 
                show_interval=5,
                style_loss_weight=3.5,
                pre_trained_model=pre_trained_model)

[Errno 2] No such file or directory: 'Style_Transfer/dataset/content_imgs_256.pkl'


TypeError: ignored

In [None]:
smodel.load_weight()
smodel.train(data_gen, 500, augment_factor=0)

In [None]:
urls = [
    # 'https://github.com/elleryqueenhomels/arbitrary_style_transfer/raw/master/images/style_thumb/escher_sphere_thumb.jpg',
    # 'https://github.com/elleryqueenhomels/arbitrary_style_transfer/raw/master/images/style_thumb/udnie_thumb.jpg',
    # 'https://github.com/elleryqueenhomels/arbitrary_style_transfer/raw/master/images/style_thumb/mosaic_thumb.jpg',
    # 'https://github.com/elleryqueenhomels/arbitrary_style_transfer/raw/master/images/style_thumb/cat_thumb.jpg'
    'https://github.com/lengstrom/fast-style-transfer/blob/master/examples/style/rain_princess.jpg?raw=true',
    'https://github.com/lengstrom/fast-style-transfer/blob/master/examples/style/wave.jpg?raw=true',
]
cimg = utils.http_get_img(
    'https://yt3.ggpht.com/a/AATXAJx3V2SYpa27ubB-eIw_vzBgS1QHKcBGj5xAZZ7dQQ=s900-c-k-c0xffffffff-no-rj-mo',
    # 'https://github.com/elleryqueenhomels/arbitrary_style_transfer/blob/master/images/content/stata.jpg?raw=true',
    512
)
cv2_imshow(utils.deprocess(utils.de_norm(cimg[0])))
for url in urls:
    simg = utils.http_get_img(url, 512)
    gen = smodel.generate(cimg, simg)
    
    cv2_imshow(utils.deprocess(utils.de_norm(simg[0])))
    cv2_imshow(utils.deprocess(utils.de_norm(gen[0])))