<a href="https://colab.research.google.com/github/Giogia/gatys_piu_bello/blob/master/Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import tensorflow as tf
import numpy as np
import os
from pathlib import Path
from functools import reduce
from operator import mul
from google.colab import files
from tensorflow.keras import Model
from numpy import expand_dims, array
from PIL import Image
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.applications import vgg19
from tensorflow.keras.applications.vgg19 import preprocess_input
from IPython.display import HTML, display

In [0]:
# VGG19 layers whose Gram matrices make up the style representation.
STYLE_LAYERS = [
    'block1_conv1',
    'block2_conv1',
    'block3_conv1',
    'block4_conv1',
    'block5_conv1',
]
# VGG19 layer whose activations make up the content representation.
CONTENT_LAYERS = [
    'block4_conv2',
]

def net_pro(img):
    """Run ``img`` through a freshly built VGG19 feature extractor.

    ``get_model`` is invoked on every call, so each call builds its own model.
    The result is the model's output list: style-layer outputs followed by
    content-layer outputs.
    """
    extractor = get_model()
    features = extractor(img)
    return features

def get_model():
    """Build a VGG19 feature extractor for the configured style/content layers.

    Returns:
        A Keras ``Model`` mapping an image batch to a list containing one
        output per entry of ``STYLE_LAYERS`` followed by one output per entry
        of ``CONTENT_LAYERS``, in that order.
    """
    # include_top=False drops the fully connected classifier. None of the
    # selected layers depend on it, and without it the network accepts any
    # spatial size: the classifier variant is tied to 224x224 inputs, while
    # training feeds 256x256 crops and the style image keeps its own size.
    model = vgg19.VGG19(include_top=False, weights='imagenet')
    model.trainable = False
    style_outputs = [model.get_layer(name).output for name in STYLE_LAYERS]
    content_outputs = [model.get_layer(name).output for name in CONTENT_LAYERS]
    return Model(model.input, style_outputs + content_outputs)
  
def get_feat_style(net, layer):
    """Return the feature tensor of style layer ``layer`` from ``net``.

    ``net`` is the output list of the model built by ``get_model``, which
    lists the style-layer outputs first (in ``STYLE_LAYERS`` order). The
    style index therefore needs no offset; adding ``len(CONTENT_LAYERS)``
    would select the next layer, or the content layer for the last style
    layer.
    """
    return net[STYLE_LAYERS.index(layer)]
  
def get_feat_content(net, layer):
    """Return the feature tensor of content layer ``layer`` from ``net``.

    ``net`` is the output list of the model built by ``get_model``, where the
    content-layer outputs come after all style-layer outputs. The content
    index is therefore offset by ``len(STYLE_LAYERS)``.
    """
    return net[len(STYLE_LAYERS) + CONTENT_LAYERS.index(layer)]

In [0]:
# Standard deviation of the truncated normal used to initialise conv kernels.
WEIGHTS_INIT_STDEV = 0.1


def net(image):
    """Image transformation network (the trainable style generator).

    Layers, in order:
      * 9x9/32 stride-1, 3x3/64 stride-2 and 3x3/128 stride-2 convolutions;
      * five residual blocks with 3x3 kernels;
      * 3x3 transposed convolutions to 64 and then 32 channels, both stride 2;
      * a final 9x9 convolution to 3 channels without ReLU.

    The final activations go through tanh and are mapped into
    (127.5 - 150, 127.5 + 150).
    """
    x = conv_layer(image, 32, 9, 1)
    x = conv_layer(x, 64, 3, 2)
    x = conv_layer(x, 128, 3, 2)
    for _ in range(5):
        x = residual_block(x, 3)
    x = conv_tranpose_layer(x, 64, 3, 2)
    x = conv_tranpose_layer(x, 32, 3, 2)
    x = conv_layer(x, 3, 9, 1, is_relu=False)
    return tf.nn.tanh(x) * 150 + 255. / 2


def conv_layer(image, filter_number, filter_size, strides, is_relu=True):
    """Apply convolution, then instance normalisation, then optionally ReLU.

    A new kernel variable of shape
    ``[filter_size, filter_size, in_channels, filter_number]`` is created for
    every call. It is applied with 'SAME' padding and ``strides`` on both
    spatial axes.
    """
    kernel = conv_initialization_vars(image, filter_number, filter_size)
    out = tf.nn.conv2d(image, kernel, [1, strides, strides, 1], padding='SAME')
    out = _instance_norm(out)
    return tf.nn.relu(out) if is_relu else out


def conv_tranpose_layer(img, filter_number, filter_size, strides):
    """Upsample with a transposed convolution, then apply instance norm and ReLU.

    The output shape is ``[batch, rows * strides, cols * strides,
    filter_number]``, derived from the static shape of ``img``.
    """
    kernel = conv_initialization_vars(img, filter_number, filter_size, transpose=True)

    batch_size, rows, cols, _ = img.get_shape()
    output_shape = tf.stack([batch_size, int(rows * strides), int(cols * strides), filter_number])

    upsampled = tf.nn.conv2d_transpose(
        img, kernel, output_shape, [1, strides, strides, 1], padding='SAME')
    return tf.nn.relu(_instance_norm(upsampled))


def residual_block(img, filter_size=3):
    """Apply two 128-filter stride-1 convolutions with an identity skip connection.

    The second convolution has no ReLU, and its output is added to ``img``.
    This presumes ``img`` has 128 channels so that the sum lines up.
    """
    hidden = conv_layer(img, 128, filter_size, 1)
    residual = conv_layer(hidden, 128, filter_size, 1, is_relu=False)
    return img + residual


def _instance_norm(img):
    """Apply instance normalisation with a learnable per-channel scale and shift.

    Statistics are taken per sample and channel over axes 1 and 2. ``scale``
    starts at 1, ``shift`` starts at 0, and epsilon is 1e-3.
    """
    _, _, _, channels = img.get_shape()

    mean, variance = tf.nn.moments(img, [1, 2], keep_dims=True)
    # Keep the creation order (shift before scale): default tf.Variable
    # names depend on it.
    shift = tf.Variable(tf.zeros([channels]))
    scale = tf.Variable(tf.ones([channels]))

    normalized = (img - mean) / (variance + 1e-3) ** (.5)
    return scale * normalized + shift


def conv_initialization_vars(network, out_channels, filter_size, transpose=False):
    """Create a truncated-normal convolution kernel variable sized for ``network``.

    The kernel layout is ``[filter_size, filter_size, in, out]`` for a regular
    convolution and ``[filter_size, filter_size, out, in]`` when ``transpose``
    is set (the filter layout ``tf.nn.conv2d_transpose`` takes). Values use
    stddev ``WEIGHTS_INIT_STDEV`` and op seed 1.
    """
    _, _, _, in_channels = [dim.value for dim in network.get_shape()]

    if transpose:
        kernel_shape = [filter_size, filter_size, out_channels, in_channels]
    else:
        kernel_shape = [filter_size, filter_size, in_channels, out_channels]

    initial = tf.truncated_normal(kernel_shape, stddev=WEIGHTS_INIT_STDEV, seed=1)
    return tf.Variable(initial, dtype=tf.float32)

In [0]:
def load_image(path, max_dim=1024):
    """Load an image file as an RGB array, capping its longer side.

    Args:
        path: file path (str or Path) accepted by ``PIL.Image.open``.
        max_dim: images whose longer side exceeds this are downscaled
            proportionally so that side becomes ``max_dim`` pixels. Smaller
            images are left as-is.

    Returns:
        The array produced by ``img_to_array`` for the RGB image. With the
        default channels_last format this is (height, width, 3).
    """
    # The context manager closes the file handle opened by Image.open.
    # resize()/convert() return new, fully loaded images, so the result stays
    # usable after the handle is closed.
    with Image.open(path) as img:
        scale = max_dim / max(img.size)
        if scale < 1:
            scaled_width = round(img.size[0] * scale)
            scaled_height = round(img.size[1] * scale)
            img = img.resize((scaled_width, scaled_height))

        # convert greyscale / palette / RGBA images to 3-channel RGB
        img = img.convert("RGB")

    return img_to_array(img)


def preprocess_image(img):
    """Add a leading batch axis and apply VGG19's ``preprocess_input``.

    ``preprocess_input`` reorders channels to BGR and subtracts the mean
    [103.939, 116.779, 123.68].
    """
    batched = expand_dims(img, axis=0)
    return preprocess_input(batched)
  

def get_img(src, img_size=False):
    """Read ``src`` as an RGB ``uint8`` array, optionally resizing it.

    Uses PIL instead of ``scipy.misc.imread``/``imresize``, which have been
    removed from SciPy.

    Args:
        src: image file path.
        img_size: a falsy value for no resize, otherwise a
            ``(height, width[, channels])`` target. Only height and width
            are used, since the image is always converted to 3-channel RGB.

    Returns:
        ``numpy.ndarray`` of shape (height, width, 3) with dtype uint8.
    """
    # convert('RGB') also expands greyscale/palette images to three channels.
    with Image.open(src) as image:
        image = image.convert('RGB')
    if img_size:
        height, width = int(img_size[0]), int(img_size[1])
        # PIL takes (width, height); bilinear matches imresize's default interp.
        image = image.resize((width, height), Image.BILINEAR)
    return np.array(image)

In [0]:
# Mean pixel (R, G, B order) subtracted from images before they enter VGG.
M_PIXEL = np.array((123.68, 116.779, 103.939))


def get_style_loss(layers, batch_size, style_features, style_weight, net):
    """Weighted Gram-matrix style loss of the generated image features.

    For each style layer, the Gram matrix of the generated activations is
    divided by ``height * width * filters``. The squared distance to the
    target Gram is divided by the number of Gram entries.

    Args:
        layers: style layer names.
        batch_size: images per batch; the summed loss is divided by it.
        style_features: mapping layer name -> target Gram matrix (numpy).
        style_weight: multiplier for the style term.
        net: VGG output list for the generated images.

    Returns:
        Scalar tensor with the weighted style loss.
    """
    style_losses = []
    for style_layer in layers:
        layer = get_feat_style(net, style_layer)
        bs, height, width, filters = map(lambda i: i.value, layer.get_shape())
        size = height * width * filters
        feats = tf.reshape(layer, (bs, height * width, filters))
        feats_t = tf.transpose(feats, perm=[0, 2, 1])
        grams = tf.matmul(feats_t, feats) / size
        style_gram = style_features[style_layer]
        # Normalise by the number of Gram entries (.size). Dividing by the
        # shape tuple would broadcast the scalar loss into a 2-element vector.
        style_losses.append(2 * tf.nn.l2_loss(grams - style_gram) / style_gram.size)
    style_loss = style_weight * reduce(tf.add, style_losses) / batch_size
    return style_loss


def get_content_loss(layer, batch_size, content_features, content_weight, net):
    """Weighted feature-reconstruction loss for content layer ``layer``.

    Computes ``content_weight * 2 * l2_loss(generated - target)``, divided by
    the per-example element count times ``batch_size``.
    """
    generated = get_feat_content(net, layer)
    target = content_features[layer]
    per_example = t_size(target)
    assert per_example == t_size(generated)
    difference = generated - target
    return content_weight * 2 * tf.nn.l2_loss(difference) / (per_example * batch_size)


def t_size(tensor):
    """Return the product of all static dimensions after the first (batch) one.

    Each dimension is read via ``.value``. A rank-1 shape yields 1.
    """
    total = 1
    for dim in tensor.get_shape()[1:]:
        total = total * dim.value
    return total


def compute_loss(content_ph, style_target, weights, tv_weight, layers, batch_size, batch_shape, gen_network_vgg):
    """Build the total training loss for the transformation network ``net``.

    Args:
        content_ph: content-image placeholder of shape ``batch_shape``. The
            training loop feeds raw 0-255 pixel values.
        style_target: style image array with a leading batch axis.
        weights: ``(content_weight, style_weight)``.
        tv_weight: weight of the total-variation smoothing term.
        layers: ``(content_layers, style_layers)``.
        batch_size: images per batch.
        batch_shape: ``(batch, height, width, channels)`` of ``content_ph``.
        gen_network_vgg: callable returning the VGG feature list for an image batch.

    Returns:
        Scalar tensor: content loss + style loss + total-variation loss.
    """
    style_features = precompute_style_features(layers[1], style_target, gen_network_vgg)
    content_features = get_content_net_and_features(layers[0], content_ph, gen_network_vgg)

    # The generator receives inputs scaled to [0, 1] and produces pixel-range
    # output. Keep a handle on that output: the TV term needs it.
    preds = net(content_ph / 255.0)
    gen_features = gen_network_vgg(preds - M_PIXEL)

    content_loss = get_content_loss(layers[0][0], batch_size, content_features, weights[0], gen_features)
    style_loss = get_style_loss(layers[1], batch_size, style_features, weights[1], gen_features)

    # Total-variation denoising must act on the generated image. Computed on
    # the input placeholder it is a constant with no gradient.
    tv_y_size = t_size(preds[:, 1:, :, :])
    tv_x_size = t_size(preds[:, :, 1:, :])
    y_tv = tf.nn.l2_loss(preds[:, 1:, :, :] - preds[:, :batch_shape[1] - 1, :, :])
    x_tv = tf.nn.l2_loss(preds[:, :, 1:, :] - preds[:, :, :batch_shape[2] - 1, :])
    tv_loss = tv_weight * 2 * (x_tv / tv_x_size + y_tv / tv_y_size) / batch_size

    return content_loss + style_loss + tv_loss


def get_content_net_and_features(layers, content_ph, vgg):
    """Map each content layer name to its VGG feature tensor for ``content_ph``.

    The images are mean-centred with ``M_PIXEL`` before entering ``vgg``.
    """
    content_net = vgg(content_ph - M_PIXEL)
    return {layer: get_feat_content(content_net, layer) for layer in layers}

def precompute_style_features(layers, style_target, vgg):
    """Evaluate the target Gram matrix of every style layer as numpy arrays.

    ``style_target`` is fed through a float32 placeholder of the same shape,
    mean-centred with ``M_PIXEL``, and evaluated with ``Tensor.eval`` (the
    current default session). Each Gram matrix is ``F.T @ F / F.size``, where
    ``F`` is the activation reshaped to ``(-1, channels)``.
    """
    style_image = tf.placeholder(tf.float32, style_target.shape, name='style_image')
    style_net = vgg(style_image - M_PIXEL)
    feed = {style_image: np.array(style_target)}

    grams = {}
    for layer in layers:
        activation = get_feat_style(style_net, layer).eval(feed_dict=feed)
        flat = np.reshape(activation, (-1, activation.shape[3]))
        grams[layer] = np.matmul(flat.T, flat) / flat.size
    return grams

def generate_net(content_ph, vgg):
    """Return VGG features of the image produced by the transformation network ``net``."""
    stylised = net(content_ph)
    return vgg(stylised - M_PIXEL)

In [0]:
def optimize(content_folder, style, content_weight, style_weight, tv_weight, vgg, 
             epochs=2, batch_size=4, save_path='fns.ckpt', learning_rate=1e-3):
    """Train the transformation network for one style image and save a checkpoint.

    Args:
        content_folder: directory holding the content training images.
        style: path of the style image.
        content_weight: weight of the content loss.
        style_weight: weight of the style loss.
        tv_weight: weight of the total-variation loss.
        vgg: callable returning the VGG feature list (e.g. ``net_pro``).
        epochs: passes over the content images.
        batch_size: images per training step. Trailing images that do not
            fill a whole batch are skipped in every epoch.
        save_path: checkpoint path (str or Path) for ``tf.train.Saver.save``.
        learning_rate: Adam learning rate.

    Raises:
        ValueError: if ``content_folder`` has fewer than ``batch_size`` entries.
    """
    # Keep the style image as raw 0-255 RGB with a batch axis.
    # precompute_style_features subtracts M_PIXEL itself, as the content path
    # does, so applying preprocess_image here would mean-centre (and BGR-swap)
    # the style image a second time.
    style_target = expand_dims(load_image(style), axis=0)
    content_targets = get_files(content_folder)

    batch_shape = (batch_size, 256, 256, 3)
    # The placeholder has a fixed batch dimension, so a short final batch
    # cannot be fed. Only whole batches are used.
    num_batches = len(content_targets) // batch_size
    if num_batches == 0:
        raise ValueError("need at least %d content images, found %d"
                         % (batch_size, len(content_targets)))

    layers = CONTENT_LAYERS, STYLE_LAYERS
    weights = content_weight, style_weight

    with tf.Graph().as_default(), tf.Session() as sess:
        x_content = tf.placeholder(tf.float32, shape=batch_shape, name="x_content")
        # overall loss
        loss = compute_loss(x_content, style_target, weights, tv_weight, layers, batch_size, batch_shape, vgg)
        # NOTE(review): minimize() optimises every trainable variable in the
        # graph. Confirm the VGG weights are not in that collection; if they
        # are, pass var_list with only the generator variables.
        train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss)

        # Initialise only variables that are still uninitialised (generator
        # and Adam slots). The VGG weights are already live in this session
        # (the style Grams were evaluated with them inside compute_loss), and
        # a global initializer would reset them to random values.
        global_vars = tf.global_variables()
        ready = sess.run([tf.is_variable_initialized(v) for v in global_vars])
        sess.run(tf.variables_initializer(
            [v for v, is_ready in zip(global_vars, ready) if not is_ready]))

        print("started training")
        progressbar = display(progress(0, epochs), display_id=True)
        for epoch in range(epochs):
            for batch in range(num_batches):
                batch_paths = content_targets[batch * batch_size:(batch + 1) * batch_size]
                x_batch = np.array(
                    [get_img(img_p, batch_shape[1:]).astype(np.float32) for img_p in batch_paths],
                    dtype=np.float32)
                train_step.run(feed_dict={x_content: x_batch})
            progressbar.update(progress(epoch + 1, epochs))

        saver = tf.train.Saver()
        saver.save(sess, str(save_path))

In [0]:
def progress(value, max=100):
    """Return an IPython ``HTML`` object rendering a full-width <progress> bar.

    Args:
        value: current progress value.
        max: value at which the bar is full. The name shadows the builtin
            inside this function but is kept for keyword callers.
    """
    # Attributes inside an HTML tag are whitespace-separated; a comma after an
    # attribute value is not valid separator syntax.
    return HTML("""
        <progress
            value='{value}'
            max='{max}'
            style='width: 100%'
        >
            {value}
        </progress>
    """.format(value=value, max=max))

In [0]:
# Working directory at import time. Uploaded images are stored beneath it.
PATH = os.getcwd()
IMAGES_PATH = "{}/Images".format(PATH)
CONTENT_IMAGE_PATH = "{}/Content".format(IMAGES_PATH)


def find_file(filename, directory):
    """Return the first entry in ``directory`` whose name minus extension equals ``filename``.

    Entries are scanned in ``os.listdir`` order. Returns ``None`` when
    nothing matches.
    """
    matches = (entry for entry in os.listdir(directory)
               if os.path.splitext(entry)[0] == filename)
    return next(matches, None)

    
def create_directory(path):
    """Create directory ``path`` with ``os.mkdir`` unless it already exists."""
    if os.path.isdir(path):
        return
    os.mkdir(path)
  
  
def upload_files(path, number, message):
    """Prompt for Colab uploads into ``path`` until it holds at least ``number`` entries.

    Each round switches the working directory to ``path`` for
    ``files.upload()`` (presumably so uploads land there) and always switches
    back to ``PATH``. It then waits 30 seconds before recounting.

    Args:
        path: directory that should receive the uploads.
        number: minimum entry count required in ``path``.
        message: text appended to the "Please upload at least " prompt.
    """
    import time  # needed for the pause between upload rounds

    while len(os.listdir(path)) < number:
        print("Please upload at least " + message)
        os.chdir(path)
        try:
            files.upload()
        finally:
            # Restore the working directory even if the upload is interrupted.
            os.chdir(PATH)
        time.sleep(30)
    
      
def list_files(path):
    """Print each entry of ``path`` without its extension, one per line."""
    stems = (os.path.splitext(entry)[0] for entry in os.listdir(Path(path)))
    for stem in stems:
        print(stem)
        
def get_files(path):
    """Return ``os.path.join(path, entry)`` for every entry listed in ``path``."""
    entries = os.listdir(Path(path))
    return [os.path.join(path, entry) for entry in entries]

def check_error(var, path):
    """Keep prompting until the user names an existing file in ``path``.

    If ``var`` is not ``None`` it is returned unchanged. Otherwise the loop
    prints a hint, lists the files in ``path``, and reads a name from stdin.
    It repeats until ``find_file`` resolves that name.

    Returns:
        The resolved file name (with extension).
    """
    # Identity comparison with None: `== None` can be overridden by __eq__.
    while var is None:
        print("Please insert a correct Name (case sensitive input)")
        list_files(path)
        var = find_file(input(), Path(path))

    return var
  
if __name__ == "__main__":
    # Interactive Colab driver: collect content and style images, then train.

    create_directory(IMAGES_PATH)
    create_directory(CONTENT_IMAGE_PATH)

    print("Upload Content Images:")

    upload_files(CONTENT_IMAGE_PATH, 10000, "10000 images")

    print("Using the following images for content training:")
    list_files(CONTENT_IMAGE_PATH)

    content_path = Path(CONTENT_IMAGE_PATH + '/')

    print("Select Style Image:")

    # IMAGES_PATH already contains the Content directory, so 2 entries means
    # at least one style image has been uploaded.
    upload_files(IMAGES_PATH, 2, "style")

    list_files(IMAGES_PATH)

    style = find_file(input(), Path(IMAGES_PATH))
    style = check_error(style, IMAGES_PATH)
    style_path = Path(IMAGES_PATH + "/" + style)

    output = 'Model_' + os.path.splitext(style)[0]
    output_path = Path(PATH + "/" + output + ".ckpt")
    # Create (truncate) an empty file at the checkpoint name and close it
    # immediately instead of leaving the handle open.
    with open(output + ".ckpt", "wb"):
        pass
    content_weight = 7.5e0
    style_weight = 1e2
    tv_weight = 2e2
    batch_size = 4
    epochs = 1000
    learning_rate = 1e-3
    optimize(content_path, style_path, content_weight, style_weight, tv_weight,
             net_pro, epochs, batch_size, output_path, learning_rate)

    print("train model saved in Files folder, refresh Files to see it")