In [1]:
from Model import *
from Image import *

In [2]:
import time

import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

In [3]:
# Pick the compute device once and size the output image to match it:
# a full 512px image on GPU, a smaller 256px image on CPU to keep runs fast.
if torch.cuda.is_available():
    device = torch.device("cuda")
    imsize = 512
else:
    device = torch.device("cpu")
    imsize = 256  # use small size if no gpu

In [4]:
# Names of the layers (in network order) at which the style loss is measured:
# the outputs of the first five convolutions, i.e. 'conv_1' .. 'conv_5'.
default_style_layers = ['conv_{}'.format(k) for k in range(1, 6)]

In [None]:
def eval_model(pre_model, img_1, img_2,
               default_mean_std=True,
               style_layers=default_style_layers,
               weight=1000000):
    """Return the weighted style (Gram-matrix) loss between two images.

    Both images are normalized with ``Normalization`` and pushed, layer by
    layer, through a deep copy of ``pre_model``. At every layer whose
    generated name appears in ``style_layers``, a ``StyleLoss`` is built
    from ``img_2``'s (detached) activations and evaluated on ``img_1``'s
    activations. The per-layer losses are averaged over
    ``len(style_layers)`` and scaled by ``weight``.

    Args:
        pre_model: module whose direct children are iterated in order.
            Presumably a VGG-style ``features`` stack made of Conv2d, ReLU,
            MaxPool2d and BatchNorm2d layers (TODO confirm). It is
            deep-copied, so the caller's model is never modified.
        img_1: image whose style is measured against ``img_2``.
        img_2: style reference image; its activations are detached.
        default_mean_std: forwarded to ``Normalization``.
        style_layers: layer names ('conv_k', 'relu_k', 'maxpool_k', 'bn_k',
            where k is the index of the most recent conv) at which the loss
            is taken.
        weight: scale factor applied to the averaged loss.

    Returns:
        ``weight`` times the averaged style loss, or ``0`` when no layer
        in ``style_layers`` is reached.

    Raises:
        RuntimeError: if ``pre_model`` contains a child of a layer type
            that cannot be named.
    """
    cnn = copy.deepcopy(pre_model)
    # This is an evaluation: use BatchNorm running statistics (and disable
    # dropout, if any) on the private copy, leaving the caller's mode alone.
    cnn.eval()

    # normalization module, applied to both images before the network
    normalization = Normalization(default_mean_std=default_mean_std)

    # Running activations of both images. Applying each layer once is
    # equivalent to re-running the whole prefix at every style layer, but
    # linear instead of quadratic in network depth.
    feat_1 = normalization(img_1)
    feat_2 = normalization(img_2)

    style_losses = 0
    # Style layers not reached yet; once empty, later layers cannot matter.
    pending = set(style_layers)

    # increment every time we see a conv
    i = 0
    for layer in cnn.children():
        if not pending:
            break

        if isinstance(layer, nn.Conv2d):
            i += 1
            name = 'conv_{}'.format(i)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(i)
            # The in-place ReLU does not play nicely with the StyleLoss
            # computed on its output, so replace it with an out-of-place one.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'maxpool_{}'.format(i)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(i)
        else:
            # Without this, `name` would be unbound or silently reuse the
            # previous layer's name and mis-attribute the style loss.
            raise RuntimeError(
                'Unrecognized layer: {}'.format(layer.__class__.__name__))

        feat_1 = layer(feat_1)
        feat_2 = layer(feat_2)

        if name in pending:
            pending.discard(name)
            # Target statistics come from img_2. Calling the loss module on
            # img_1's activations sets `.loss`; this assumes the
            # tutorial-style StyleLoss whose forward stores the loss
            # (TODO confirm against Model).
            style_loss = StyleLoss(feat_2.detach())
            style_loss(feat_1)
            style_losses += style_loss.loss / len(style_layers)

    return style_losses * weight