Stylization, Readme
DmitryUlyanov committed Mar 11, 2016
1 parent 0684878 commit c93cd6a
Showing 6 changed files with 55 additions and 24 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -9,3 +9,4 @@ data/*/*
!data/out/.gitkeep

!data/textures
!data/textures/*
60 changes: 45 additions & 15 deletions README.md
@@ -1,6 +1,6 @@
Learn a neural network from one image!

In [our paper](http://arxiv.org/abs/1603.03417) we describe a faster way to generate textures and stylize images. It requires learning a feedforward generator with the loss function proposed by [Gatys et al.](http://arxiv.org/abs/1505.07376), which in our experiments takes about an hour or two. Once the model is trained, a texture sample of any size can be generated instantly.
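To make "instantly" concrete: sampling from a trained generator is a single forward pass through a small convolutional network. Below is a minimal hypothetical sketch of the idea; the file path and the single-tensor noise layout are illustrative assumptions (the pyramid generators actually consume noise at several scales), so see `texture_sample.lua` for the real entry point.

```
require 'torch'
require 'nn'
require 'cunn'
require 'image'

-- Hypothetical sketch: load a trained generator and draw a sample with a
-- single forward pass. The real noise layout depends on the model (the
-- pyramid generators consume noise at several scales); a single uniform
-- noise tensor is assumed here for brevity.
local generator = torch.load('data/out/model.t7')
local noise = torch.CudaTensor(1, 3, 512, 512):uniform()
local sample = generator:forward(noise)
image.save('sample.jpg', sample[1]:float():clamp(0, 1))
```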

## Pretrained models
You can find two `iTorch` notebooks as well as 8 pretrained models in the `supplementary` directory. You need a GPU (`nn.SpatialBatchNormalization` throws an error in CPU mode), `torch`, and `iTorch` installed to try them.
@@ -54,29 +54,59 @@ th texture_sample.lua -model data/out/model.t7 -noise_depth 3 -sample_size 512

## Stylization

### Prepare

We used the ILSVRC2012 validation set to train the generator. One pass through the data was more than enough for the model described in the paper.

Extract content features from the `relu4_2` layer:
```
th scripts/extract4_2.lua -images_path <path/ILSVRC2012>
```
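Under the hood, this step amounts to running every image through a pretrained descriptor network and saving the `relu4_2` activations. A rough sketch of the idea, assuming VGG-19 via `loadcaffe` and an HDF5 output; the file names, layer index, and dataset layout are illustrative, not the script's actual format:

```
require 'torch'
require 'nn'
require 'image'
local hdf5 = require 'hdf5'
local loadcaffe = require 'loadcaffe'

-- Load VGG-19 and truncate it just after relu4_2 (index 23 in the
-- loadcaffe conversion of VGG-19; verify by printing the model).
local cnn = loadcaffe.load('VGG_ILSVRC_19_layers_deploy.prototxt',
                           'VGG_ILSVRC_19_layers.caffemodel', 'nn'):float()
for i = #cnn.modules, 24, -1 do cnn:remove(i) end

-- Save the relu4_2 activations of one image under an illustrative name.
-- Preprocessing (BGR channel order, mean subtraction) is omitted for brevity.
local img = image.scale(image.load('example.jpg', 3), 256, 256):float()
local out = hdf5.open('contents.h5', 'w')
out:write('/example', cnn:forward(img))
out:close()
```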
### Train

Use this command to train a generator to stylize as in the example below.
```
th stylization_train.lua -style_image data/textures/cezanne.jpg -train_hdf5 <path/to/generated/hdf5> -noise_depth 3 -model_name pyramid -normalize_gradients -train_images_path <path/to/ILSVRC2012> -content_weight 0.8
```
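The objective follows Gatys et al.: a texture term that matches Gram matrices of descriptor activations against those of the style image, plus a content term on the `relu4_2` activations, weighted by `content_weight`. A minimal sketch of the losses at the heart of it (the function names are illustrative, not the ones used in this repository):

```
require 'torch'

-- Gram matrix of a C x H x W feature map: G[i][j] = <F_i, F_j>,
-- normalized by the number of spatial locations.
local function gram(features)
  local C, H, W = features:size(1), features:size(2), features:size(3)
  local F = features:view(C, H * W)
  return torch.mm(F, F:t()):div(H * W)
end

-- Texture term: squared Frobenius distance between the Gram matrices of
-- the generated image's activations and the style image's activations.
local function texture_loss(gen_feat, style_feat)
  return (gram(gen_feat) - gram(style_feat)):pow(2):sum()
end

-- Content term: squared distance between raw relu4_2 activations.
local function content_loss(gen_feat, content_feat)
  return (gen_feat - content_feat):pow(2):sum()
end
```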
### Process
TODO

### Example

![Cezanne](data/textures/cezanne.jpg)

![Original](data/readme_pics/kitty.jpg)

![Processed](data/readme_pics/kitty_cezanne.jpg)

#### Variations
We were not able to achieve results similar to the original paper of L. Gatys on artistic style, which is partially explained by the balance problem (read the paper for details). Yet, while not transferring the style exactly as expected, the models produce nice pictures. We tried several hacks redefining the objective function to be more suitable for a convolutional parametric generator; none of them worked considerably better, but the results were nice.

For the next pair we used a generator trained on only 16 images. Funnily enough, it did not overfit. Also, in this setting the network takes much longer to degrade when zero padding is used. Note that the tiger image was not in the training set.

![Tiger](data/readme_pics/tiger.jpg)

![Tiger_processed](data/readme_pics/tiger_starry.jpg)
Using "Starry night" by Van Gogh. It takes about quarter of second to process an image at `1024 x 768` resolution.

The result is quite different from what the original algorithm would generate, but still quite nice.

In one of the experiments the generator failed to learn Van Gogh's style, but the results still turned out very stylish.

![Pseudo](data/readme_pics/pseudo.png)

This model tried to fit both the texture and content losses on a fixed set of 16 images, and only the content loss on a large number of images.


# Hardware
- Tested with a 12GB NVIDIA Tesla K40m GPU on Ubuntu 14.04.
- You may decrease `batch_size`, `image_size`, `noise_depth` if the model does not fit in your GPU memory (see the example after this list).
- `pyramid2` is much more memory efficient than `pyramid`; moreover, you can decrease the number of filters it uses.
- The pretrained models do not need much memory to sample from.
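For example, a lower-memory training run might look like the following. The option values are illustrative, and `-batch_size`/`-image_size` are assumed to be exposed as command-line flags, as the list above suggests:
```
th stylization_train.lua -style_image data/textures/cezanne.jpg -train_hdf5 <path/to/generated/hdf5> -model_name pyramid2 -noise_depth 2 -batch_size 1 -image_size 256 -normalize_gradients -train_images_path <path/to/ILSVRC2012> -content_weight 0.8
```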

# Credits

The code is based on [Justin Johnson's great code](https://github.com/jcjohnson/neural-style) for artistic style.

The work is supported by Yandex.
Binary file modified data/readme_pics/kitty.jpg
Binary file added data/readme_pics/kitty_cezanne.jpg
Binary file added data/textures/cezanne.jpg
18 changes: 9 additions & 9 deletions stylization_train.lua
@@ -52,7 +52,7 @@ params.texture = params.style_image

if params.backend == 'cudnn' then
require 'cudnn'
cudnn.benchmark = true
backend = cudnn
else
backend = nn
@@ -87,13 +87,13 @@ end

local train_hdf5 = hdf5.open(params.train_hdf5, 'r')

-- Allocate reusable space
inputs_batch = torch.Tensor(params.batch_size, net_input_depth, params.image_size, params.image_size)
contents_batch = torch.Tensor(params.batch_size, 512, params.image_size/8, params.image_size/8)

cur_index_train = 1
function get_input_train()
-- Ignore the last partial batch for simplicity
if cur_index_train > #image_names - params.batch_size then
cur_index_train = 1
end
@@ -113,7 +113,7 @@ end

iteration = 0

-- Dummy storage, this will not be changed during training
inputs_batch = torch.Tensor(params.batch_size, net_input_depth, params.image_size, params.image_size):uniform():cuda()

local parameters, gradParameters = net:getParameters()
@@ -126,21 +126,21 @@ function feval(x)
end
gradParameters:zero()

-- Get batch
local images, contents = get_input_train()

-- Set current `relu4_2` content
content_losses[1].target = contents

-- Forward
local out = net:forward(images)
descriptor_net:forward(out)

-- Backward
local grad = descriptor_net:backward(out, nil)
net:backward(images, grad)

-- Collect loss
local loss = 0
for _, mod in ipairs(texture_losses) do
loss = loss + mod.loss
