In [None]:
import torch
import os
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image

In [None]:
# Side length, in pixels, that input images are resized to (they become
# image_size x image_size squares).
image_size = 356

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')
print(DEVICE)

In [None]:
class VGG(nn.Module):
    """Frozen VGG-19 feature extractor that returns selected intermediate activations.

    Only the first 29 modules (indices 0-28) of ``vgg19().features`` are kept;
    nothing after the last chosen layer is ever evaluated.

    Attributes:
        chosen_layers: string indices of the modules in ``self.model`` whose
            outputs are collected by ``forward``.
            NOTE(review): these presumably correspond to conv1_1 ... conv5_1 in
            torchvision's VGG-19 layout -- confirm against the torchvision source.
        model: the truncated ``features`` stack of a pretrained VGG-19.
    """

    def __init__(self):
        super(VGG, self).__init__()
        self.chosen_layers = ['0', '5', '10', '19', '28']
        # NOTE(review): recent torchvision releases prefer the `weights=` argument
        # over `pretrained=`; kept unchanged so older installs keep working.
        self.model = models.vgg19(pretrained=True).features[:29]

        # The network is only used as a fixed feature extractor: its weights are
        # never handed to an optimizer. Freezing them stops autograd from
        # computing and storing a gradient for every VGG weight on each
        # backward pass, while gradients w.r.t. the input still flow.
        for param in self.model.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Run ``x`` through the truncated network and collect chosen activations.

        Args:
            x: input tensor fed to the first module of ``self.model``.

        Returns:
            A list with one tensor per entry of ``chosen_layers`` that is hit,
            in network order (output of each chosen module).
        """
        features = []
        for layer_num, layer in enumerate(self.model):
            x = layer(x)
            # chosen_layers holds string indices, hence the str() conversion.
            if str(layer_num) in self.chosen_layers:
                features.append(x)

        return features

In [None]:
# Preprocessing pipeline shared by every input image: resize to a fixed
# image_size x image_size square, then convert the PIL image to a tensor.
loader = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])

def load_image(image_name, root_dir='../data/'):
    """Load an image file as a batched tensor on ``DEVICE``.

    Args:
        image_name: file name of the image, relative to ``root_dir``.
        root_dir: directory the image is looked up in.

    Returns:
        The image converted to RGB, passed through ``loader``, given a leading
        batch dimension of size 1, and moved to ``DEVICE``.
    """
    image_path = os.path.join(root_dir, image_name)
    rgb_image = Image.open(image_path).convert('RGB')
    # loader yields a single image; unsqueeze(0) adds the batch axis in front.
    batched = loader(rgb_image).unsqueeze(0)
    return batched.to(DEVICE)

In [None]:
# Content ('ori.png') and style ('sty.png') images, loaded in that order.
original_image, style_image = (
    load_image(file_name) for file_name in ('ori.png', 'sty.png')
)

# The image being optimised starts as a copy of the content image; it is the
# only tensor here that tracks gradients.
generated = original_image.clone().requires_grad_(True)

In [None]:
# Sanity check: show the shapes of the content, style and generated tensors,
# one per line.
print(original_image.shape, style_image.shape, generated.shape, sep='\n')

In [None]:
# Optimisation schedule.
total_steps = 6000
learning_rate = 1e-3

# Loss weights: alpha multiplies the content (original) loss and beta the
# style loss when the two are summed into the total loss.
alpha = 1
beta = 0.1

# Feature extractor on the selected device, switched to evaluation mode.
model = VGG().to(DEVICE)
model = model.eval()

# Adam is given only the generated image, so the image itself is what gets updated.
optimizer = optim.Adam([generated], lr=learning_rate)

In [None]:
def _gram_matrix(feature):
    """Return the (channel x channel) Gram matrix of one feature map.

    The feature map is flattened to (channel, height*width) and multiplied by
    its own transpose. The flattening drops the batch axis, so this only works
    for a batch of exactly one image (the view fails otherwise).
    """
    _, channel, height, width = feature.shape
    flat = feature.view(channel, height * width)
    return flat.mm(flat.t())


# The content and style images never change during optimisation, so their
# activations (and the style Gram matrices) are computed once up front instead
# of on every step. no_grad() keeps autograd from recording a graph for these
# constant targets.
with torch.no_grad():
    original_img_features = model(original_image)
    style_features = model(style_image)
    style_grams = [_gram_matrix(style_feature) for style_feature in style_features]

for step in range(total_steps):
    # Only the generated image changes between steps, so only it is re-run.
    generated_features = model(generated)

    style_loss = original_loss = 0

    for gen_feature, orig_feature, style_gram in zip(generated_features, original_img_features, style_grams):
        # Content term: mean squared difference of raw activations.
        original_loss += torch.mean((gen_feature - orig_feature)**2)

        # Style term: mean squared difference of Gram matrices.
        G = _gram_matrix(gen_feature)
        style_loss += torch.mean((G - style_gram)**2)

    total_loss = alpha*original_loss + beta*style_loss

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if step%100 == 0:
        # .item() prints the loss as a plain number rather than a tensor repr.
        print(f'step - {step}/{total_steps} || loss - {total_loss.item()}')

    # Snapshot every 1000 steps, plus once after the very last update so the
    # final result of the run is not lost.
    if step%1000 == 0 or step == total_steps - 1:
        save_image(generated, '{}_generated_image.png'.format(step))