In [None]:
class FFM(object):
    """Feed-forward style-transfer model trained against a perceptual loss.

    A convolutional ``transform_network`` maps a content batch to a stylised
    batch ("pastiche").  The layers of ``loss_network.features`` are then run
    over the pastiche, the content batch and the style image; MSE losses on
    selected convolution outputs (content) and on their Gram matrices (style)
    are back-propagated into the transform network only.

    Args:
        style: Style image tensor used whenever ``train`` is not given one.
        content: Optional content tensor; stored on the instance, not used
            by ``train``.
        pastiche: Optional tensor wrapped in ``nn.Parameter`` and stored as
            ``self.pastiche``.  The optimiser does not update it.
        alpha: Content-loss weight (default 1; tune per style image).
        beta: Style-loss weight (default 1000; tune per style image).
        loss_network: Module exposing a ``features`` container of layers.
            Defaults to torchvision's pretrained VGG19.
    """

    def __init__(self, style, content=None, pastiche=None, alpha=1, beta=1000,
                 loss_network=None):
        super(FFM, self).__init__()

        self.style = style
        self.content = content
        # Only the transform network is optimised; this parameter is kept for
        # callers that still read the attribute.
        self.pastiche = nn.Parameter(pastiche.data) if pastiche is not None else None

        self.content_layers = ['conv_4']
        self.style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
        # ``train`` reads the snake_case names; the camelCase spellings are
        # kept as aliases for existing readers of those attributes.
        self.content_weight = self.contentWeight = alpha
        self.style_weight = self.styleWeight = beta

        # ReflectionPad2d(40) offsets the 80 pixels per dimension that the ten
        # unpadded 3x3 convolutions remove at quarter resolution.
        # NOTE(review): the stack has no activation or normalisation layers,
        # so it composes only linear operations -- confirm this is intended.
        layers = [
            nn.ReflectionPad2d(40),
            nn.Conv2d(3, 32, 9, stride=1, padding=4),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
        ]
        layers.extend(nn.Conv2d(128, 128, 3, stride=1, padding=0) for _ in range(10))
        layers.extend([
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.Conv2d(32, 3, 9, stride=1, padding=4),
        ])
        self.transform_network = nn.Sequential(*layers)

        if loss_network is None:
            # Imported lazily so torchvision is only required when no loss
            # network is supplied.
            from torchvision import models
            loss_network = models.vgg19(pretrained=True)
        self.loss_network = loss_network
        # The loss network is a fixed feature extractor: never trained.
        self.loss_network.eval()
        for param in self.loss_network.parameters():
            param.requires_grad = False

        self.gram = GramMatrix()
        self.loss = nn.MSELoss()

        self.use_cuda = torch.cuda.is_available()
        if self.use_cuda:
            self.transform_network.cuda()
            self.loss_network.cuda()
            self.gram.cuda()
        else:
            self.transform_network.cpu()
            self.loss_network.cpu()
            self.gram.cpu()

        # Built after device placement so it tracks the parameters in place.
        self.optimizer = optim.Adam(self.transform_network.parameters(), lr=1e-3)

    def train(self, content, style=None):
        """Run one optimisation step of the transform network.

        Args:
            content: Batch of content images fed to the transform network.
            style: Optional style tensor; falls back to ``self.style``.

        Returns:
            Tuple ``(content_loss, style_loss, pastiche)`` where ``pastiche``
            is the transform network's output, detached from the graph.
        """
        self.optimizer.zero_grad()

        content = content.clone()
        style = (self.style if style is None else style).clone()
        if self.use_cuda:
            content, style = content.cuda(), style.cuda()

        generated = self.transform_network(content)
        # ``pastiche`` is overwritten with loss-network features below, so the
        # image itself is kept separately for the return value.
        pastiche = generated

        content_loss = 0
        style_loss = 0

        # Layers that still have to contribute; once empty, deeper layers of
        # the loss network cannot affect the loss and are skipped.
        pending = set(self.content_layers) | set(self.style_layers)

        i = 1
        for layer in self.loss_network.features.children():
            if not pending:
                break
            if isinstance(layer, nn.ReLU):
                # An in-place ReLU would overwrite conv outputs that autograd
                # still needs; use a fresh out-of-place one instead.
                layer = nn.ReLU(inplace=False)

            pastiche, content, style = layer(pastiche), layer(content), layer(style)

            if isinstance(layer, nn.Conv2d):
                name = "conv_" + str(i)

                if name in self.content_layers:
                    content_loss += self.loss(pastiche * self.content_weight,
                                              content.detach() * self.content_weight)
                if name in self.style_layers:
                    pastiche_g, style_g = self.gram(pastiche), self.gram(style)
                    style_loss += self.loss(pastiche_g * self.style_weight,
                                            style_g.detach() * self.style_weight)
                pending.discard(name)

            if isinstance(layer, nn.ReLU):
                # conv_<i> names count convolutions; each is followed by a ReLU.
                i += 1

        total_loss = content_loss + style_loss
        total_loss.backward()

        self.optimizer.step()

        return content_loss, style_loss, generated.detach()

    def save(self, path="outputs/transform_network.pth"):
        """Serialise the transform network's ``state_dict`` to ``path``."""
        torch.save(self.transform_network.state_dict(), path)

In [None]:
#PENDING
# Number of full passes main() makes over the content data loader.
num_epochs = 3
# Batch size handed to the content DataLoader; main() also uses it as the
# number of output image paths generated per save.
N = 4

def main():
    """Train an ``FFM`` on an image folder, logging losses and saving samples.

    Relies on names defined outside this function: ``style``, ``loader``,
    ``kwargs``, ``save_images``, ``num_epochs`` and ``N``.  Losses are printed
    every 10 batches; generated and content images are written under
    ``outputs/`` and the model is saved every 500 batches.
    """
    style_cnn = FFM(style)

    # Contents
    wikiart = datasets.ImageFolder(root='path/to/wikiart', transform=loader)
    content_loader = torch.utils.data.DataLoader(wikiart, batch_size=N, shuffle=True, **kwargs)
    batches_per_epoch = len(content_loader)

    for epoch in range(num_epochs):
        # ImageFolder samples are (image, class_index) pairs; only the images
        # are needed here.
        for i, (content_batch, _) in enumerate(content_loader):
            # Global step: unique and strictly increasing across epochs.
            iteration = epoch * batches_per_epoch + i
            # The style image is given to the model at construction time, so
            # only the content batch is passed.
            content_loss, style_loss, pastiches = style_cnn.train(content_batch)

            if i % 10 == 0:
                # float() reads a scalar loss without indexing a 0-dim tensor,
                # which raises on PyTorch >= 0.5.
                print("Iteration: %d" % (iteration))
                print("Content loss: %f" % float(content_loss))
                print("Style loss: %f" % float(style_loss))

            if i % 500 == 0:
                # The last batch of an epoch may hold fewer than N images.
                batch_size = len(content_batch)

                path = "outputs/%d_" % (iteration)
                paths = [path + str(n) + ".png" for n in range(batch_size)]
                save_images(pastiches, paths)

                path = "outputs/content_%d_" % (iteration)
                paths = [path + str(n) + ".png" for n in range(batch_size)]
                save_images(content_batch, paths)
                style_cnn.save()
