In [30]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
%matplotlib inline

# Choose the compute backend once for the whole notebook. The probes run in
# priority order: Apple MPS first, then CUDA, and CPU as the fallback (the
# CUDA check is only evaluated when MPS is unavailable).
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

# Tensors and modules created without an explicit device now land on `device`.
torch.set_default_device(device)

In [17]:
class Residual_Block(nn.Module):
    """Residual unit: conv -> instance norm -> ReLU -> conv -> instance norm, plus a skip connection.

    Both convolutions are 3x3 with stride 1 and padding 1 and keep the channel
    count, so the output has the same shape as the input and can be added to it.

    Args:
        channels: number of input (and output) feature channels.
    """

    def __init__(self, channels):
        super().__init__()
        # Shared geometry of the two convolutions.
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        layers = [
            nn.Conv2d(channels, channels, **conv_kwargs),
            nn.InstanceNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, **conv_kwargs),
            nn.InstanceNorm2d(channels),
        ]
        # The attribute name `conv` determines the state_dict keys ("conv.0.weight", ...).
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Return `x` plus the output of the convolutional branch applied to `x`."""
        branch = self.conv(x)
        return x + branch

In [40]:
class TransferNet(nn.Module):
    """Feed-forward image transformation network (encoder -> residual blocks -> decoder).

    * ``downSample``: a 9x9 stride-1 conv followed by two 3x3 stride-2 convs
      (3 -> 32 -> 64 -> 128 channels), each followed by instance norm and ReLU.
    * ``residual``: five ``Residual_Block(128)`` modules.
    * ``upSample``: two stride-2 transposed convs (128 -> 64 -> 32), each with
      instance norm and ReLU, then a 9x9 conv to 3 channels and ``Tanh``.

    The final ``Tanh`` bounds the output to [-1, 1]. For inputs whose height
    and width are multiples of 4 the output has the input's spatial size.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, kernel_size, stride); padding is kernel_size // 2.
        encoder_spec = ((3, 32, 9, 1), (32, 64, 3, 2), (64, 128, 3, 2))
        encoder = []
        for c_in, c_out, k, s in encoder_spec:
            encoder.append(nn.Conv2d(c_in, c_out, kernel_size=k, stride=s, padding=k // 2))
            encoder.append(nn.InstanceNorm2d(c_out))
            encoder.append(nn.ReLU())
        self.downSample = nn.Sequential(*encoder)

        self.residual = nn.Sequential(*(Residual_Block(128) for _ in range(5)))

        decoder = []
        for c_in, c_out in ((128, 64), (64, 32)):
            # output_padding=1 makes each transposed conv exactly double H and W.
            decoder.append(
                nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1)
            )
            decoder.append(nn.InstanceNorm2d(c_out))
            decoder.append(nn.ReLU())
        decoder.append(nn.Conv2d(32, 3, kernel_size=9, stride=1, padding=4))
        decoder.append(nn.Tanh())
        self.upSample = nn.Sequential(*decoder)

    def forward(self, x):
        """Run `x` through the encoder, the residual stack and the decoder."""
        encoded = self.downSample(x)
        refined = self.residual(encoded)
        return self.upSample(refined)

In [19]:
class VGG(nn.Module):
    """Frozen, ImageNet-pretrained VGG-19 convolutional stack used as a feature extractor.

    Inputs are expected to be RGB batches of shape (B, 3, H, W) scaled to
    [0, 1]; they are normalised with the ImageNet channel statistics inside
    ``forward`` because that is what the pretrained weights were trained on.
    """

    def __init__(self):
        super().__init__()
        # `pretrained=True` is deprecated in torchvision; the weights enum
        # selects the same ImageNet checkpoint explicitly.
        self.vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.eval()
        for feature in self.vgg:
            feature.requires_grad_(False)
            # torchvision builds VGG with ReLU(inplace=True). An in-place ReLU
            # would overwrite a conv output that forward() has just captured,
            # so the stored tensor would no longer be that layer's output.
            if isinstance(feature, nn.ReLU):
                feature.inplace = False
        # ImageNet mean / std, shaped to broadcast over (B, 3, H, W). Registered
        # as non-persistent buffers so they follow `.to(device)` without being
        # added to the state_dict.
        self.register_buffer(
            "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False
        )

    def forward(self, x, content_layers, style_layers):
        """Collect intermediate activations at the requested layer indices.

        Args:
            x: image batch, (B, 3, H, W), values in [0, 1].
            content_layers: indices into ``self.vgg`` whose outputs are
                returned as content features.
            style_layers: indices into ``self.vgg`` whose outputs are
                returned as style features.

        Returns:
            ``(content_features, style_features)``: two lists of tensors, each
            in increasing layer order. An index present in both arguments
            contributes to both lists.
        """
        content_idx = set(content_layers)
        style_idx = set(style_layers)
        content_features = []
        style_features = []

        wanted = content_idx | style_idx
        if not wanted:
            return content_features, style_features
        # Layers after the deepest requested one cannot affect the result.
        last_needed = max(wanted)

        x = (x - self.mean) / self.std
        for i, layer in enumerate(self.vgg):
            x = layer(x)
            if i in content_idx:
                content_features.append(x)
            # Independent `if` (not `elif`): a layer may serve both purposes.
            if i in style_idx:
                style_features.append(x)
            if i >= last_needed:
                break
        return content_features, style_features

In [34]:
def gram(x):
    """Batched, normalised Gram matrix of a feature map.

    Args:
        x: tensor of shape (B, C, H, W).

    Returns:
        Tensor of shape (B, C, C): for each batch item, the inner products
        between every pair of channels over all spatial positions, divided by
        ``C * H * W``.
    """
    B, C, H, W = x.size()
    # `reshape` instead of `view`: feature maps that are not contiguous
    # (e.g. permuted or channels_last) cannot be flattened with `view`.
    feats = x.reshape(B, C, H * W)
    return torch.bmm(feats, feats.transpose(1, 2)) / (C * H * W)

def content_loss(src, tgt):
    """Mean squared error between `src` and `tgt`.

    `tgt` is detached first, so gradients only flow back through `src`.
    """
    fixed_target = tgt.detach()
    return F.mse_loss(input=src, target=fixed_target)

def style_loss(src, tgt):
    """Mean squared error between the Gram matrices of `src` and `tgt`.

    The target Gram matrix is detached, so gradients only flow back through `src`.
    """
    src_gram = gram(src)
    tgt_gram = gram(tgt).detach()
    return F.mse_loss(input=src_gram, target=tgt_gram)

In [42]:
# Side length in pixels used by the Resize / CenterCrop steps of `loader`.
imsize = 256
# Indices of the layers, as enumerated in VGG.forward, whose outputs are
# collected as content features and as style features respectively.
content_layers = [25]
style_layers = [0, 5, 10, 19, 28]

# Number of iterations of the training loop; each iteration performs one
# optimizer step on the single content image.
epochs = 1000
# Learning rate passed to the Adam optimizer.
learning_rate = 0.01
# Multipliers applied to the summed content and style losses before they are
# added to form the training loss.
content_loss_weight = 1
style_loss_weight = 1e6

In [46]:
# Preprocessing applied to every image that image_loader opens:
# resize, centre-crop to imsize x imsize, then convert to a tensor.
loader = transforms.Compose([
    # With an int argument, Resize scales the shorter edge to `imsize`
    # and keeps the aspect ratio.
    transforms.Resize(imsize),
    # Crop the centre so the result is exactly imsize x imsize.
    transforms.CenterCrop(imsize),
    # PIL image -> float tensor in (C, H, W) layout with values in [0, 1].
    transforms.ToTensor(),
])

def image_loader(image_name):
    """Load an image file as a (1, 3, imsize, imsize) float tensor on `device`.

    Args:
        image_name: path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A float32 tensor with a leading batch dimension, preprocessed by
        `loader` and moved to `device`.
    """
    # The context manager closes the underlying file once the pixels are read.
    with Image.open(image_name) as img:
        # Force three channels: grayscale, palette or RGBA files would
        # otherwise produce 1- or 4-channel tensors that the 3-channel
        # networks cannot consume.
        image = loader(img.convert("RGB")).unsqueeze(0)
    return image.to(device, torch.float)

def image_show(tensor):
    """Display an image tensor whose values are in [-1, 1].

    Args:
        tensor: image of shape (3, H, W) or (1, 3, H, W) with values in
            [-1, 1]; it may live on any device and may require grad.
    """
    # detach() so tensors that are still part of an autograd graph can be
    # converted for display; cpu() so the conversion works from any device.
    image = tensor.detach().cpu()
    # Map [-1, 1] -> [0, 1]. Clamp because the conversion to an 8-bit image
    # does not saturate values that fall outside [0, 1].
    image = ((image + 1) / 2).clamp(0, 1)
    image = image.squeeze(0)
    image = transforms.ToPILImage()(image)
    plt.imshow(image)
    plt.show()

In [None]:
# Train TransferNet on a single content/style pair and save its weights.
content_img = image_loader('content.jpg')
style_img = image_loader('style.jpg')

vgg_model = VGG().to(device)
# The target features are constants of the optimisation, so compute them
# once without recording an autograd graph.
with torch.no_grad():
    content_features, _ = vgg_model(content_img, content_layers, style_layers)
    _, style_features = vgg_model(style_img, content_layers, style_layers)

transfer_model = TransferNet().to(device)
optimizer = optim.Adam(transfer_model.parameters(), lr=learning_rate)

transfer_model.train()
for epoch in range(epochs):
    # The network ends in Tanh, so its raw output is in [-1, 1];
    # VGG is fed the [0, 1] version.
    raw_img = transfer_model(content_img)
    gen_img = (raw_img + 1) / 2
    gen_content_features, gen_style_features = vgg_model(gen_img, content_layers, style_layers)

    content_l = sum(
        content_loss(gen, target)
        for gen, target in zip(gen_content_features, content_features)
    )
    style_l = sum(
        style_loss(gen, target)
        for gen, target in zip(gen_style_features, style_features)
    )
    loss = content_loss_weight * content_l + style_loss_weight * style_l

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 50 == 0:
        # float(...) logs plain numbers instead of tensor reprs.
        print('Epoch: {}, Loss: {}, Content Loss: {}, Style Loss: {}'.format(
            epoch + 1, float(loss), float(content_l), float(style_l)))
        # image_show maps [-1, 1] -> [0, 1] itself, so pass the raw network
        # output; passing the already rescaled `gen_img` would rescale twice
        # and wash the preview out to [0.5, 1].
        image_show(raw_img.detach())

torch.save(transfer_model.state_dict(), 'transfer_model.pth')
    

