In [1]:
# Import the necessary modules
import torch 
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image
from PIL import Image

In [2]:
# Inspect the layer layout of the pretrained VGG-19 convolutional stack.
# Only the `.features` part is kept; the bare expression on the last line
# makes the notebook render the module's repr.
model = models.vgg19(pretrained=True).features
model

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /Users/mohitsharma/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|████████████████████████████████████████| 548M/548M [02:11<00:00, 4.37MB/s]


Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [3]:
# Defining the network
class VGG(nn.Module):
    """Frozen, truncated VGG-19 used purely as a feature extractor.

    Wraps the first 29 modules of torchvision's pretrained VGG-19
    ``features`` stack. ``forward`` runs the input through them and returns
    the activations produced at the layer indices listed in
    ``chosen_features``.
    """

    def __init__(self):
        super().__init__()
        # Indices (as strings) into the ``features`` Sequential whose outputs
        # are collected by ``forward``.
        self.chosen_features = ['1', '5', '10', '19', '28']
        # The highest chosen index is 28, so everything from index 29 onward
        # is dropped.
        self.model = models.vgg19(pretrained=True).features[:29]
        # The pretrained weights are never trained here. Freezing them stops
        # autograd from computing a gradient for every VGG parameter and
        # accumulating it into ``.grad`` on each backward pass, while
        # gradients still flow back to any input that requires grad.
        for param in self.model.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Return the activations of the chosen layers for input ``x``.

        Args:
            x: Input tensor fed to the first layer of the truncated stack.

        Returns:
            A list with one tensor per entry of ``chosen_features`` that is a
            valid layer index, in ascending layer order.
        """
        features = []
        for layer_num, layer in enumerate(self.model):
            x = layer(x)
            # NOTE(review): the pretrained stack uses ReLU(inplace=True), so
            # when a chosen layer is a Conv2d directly followed by such a
            # ReLU, the tensor stored here is overwritten in place by the
            # next layer and ends up holding post-ReLU values. Confirm this
            # is the intended activation.
            if str(layer_num) in self.chosen_features:
                features.append(x)
        return features

In [4]:
#  Image Loading 
def load_image(image_name):
    """Load an image file and turn it into a batched tensor on ``device``.

    Args:
        image_name: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The output of the module-level ``loader`` transform applied to the
        image, with a leading batch dimension added via ``unsqueeze(0)`` and
        moved to the module-level ``device``.
    """
    # The context manager closes the underlying file once the pixel data has
    # been read, instead of leaving the handle open until garbage collection.
    with Image.open(image_name) as img:
        # Force three channels: grayscale, palette or RGBA files would
        # otherwise yield a tensor whose channel count does not match the
        # network's first layer, Conv2d(3, 64, ...).
        image = loader(img.convert("RGB")).unsqueeze(0)
    return image.to(device)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Side length (in pixels) that every image is resized to.
image_size = 256

# Preprocessing: resize to a square of image_size x image_size, then convert
# the PIL image to a tensor. No mean/std normalization is applied.
loader = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])

# Build the feature extractor, move it to the selected device and switch it
# to evaluation mode.
model = VGG()
model = model.to(device).eval()

In [6]:
# Define the original (content) image and the style image
original_image = load_image("Images/pexels-olly-846741.jpg")
style_image = load_image("Images/Ute Herrmann  Galleri Habs A_S  Midtjyllands Kunst Center.jpeg")

# The image being optimized starts as a copy of the original image and is
# marked as requiring gradients so the optimizer can update its pixels.
# (Alternative initialization: random noise with the same shape.)
generated = original_image.clone()
generated = generated.requires_grad_(True)

In [7]:
# Hyperparameters
total_steps = 6000      # number of optimization steps run by the training loop
learning_rate = 0.001   # step size passed to Adam
alpha = 1               # weight of the original-image (content) loss term
beta = 0.01             # weight of the style loss term

# Optimizer: `generated` is the only tensor handed to Adam, so the image
# itself is what gets updated.
optimizer = optim.Adam(params=[generated], lr=learning_rate)

In [None]:
# Training loop
# Only `generated` is optimized. Freezing the network weights stops every
# backward pass from computing gradients for them and accumulating them into
# `.grad` (optimizer.zero_grad() only clears the gradient of `generated`).
for param in model.parameters():
    param.requires_grad_(False)


def _gram_matrix(feature):
    """Return the (channel x channel) Gram matrix of one feature map.

    The batch dimension is dropped by the `view`, so this only works for a
    batch of size 1, which is what `load_image` produces via `unsqueeze(0)`.
    """
    _, channel, height, width = feature.shape
    flat = feature.view(channel, height * width)
    return flat.mm(flat.t())


# The content and style targets do not depend on `generated`, so they are
# computed once, without building an autograd graph, instead of on every step.
with torch.no_grad():
    original_features = model(original_image)
    style_features = model(style_image)
    style_grams = [_gram_matrix(style_feature) for style_feature in style_features]

for step in range(total_steps):
    generated_features = model(generated)

    style_loss = original_loss = 0
    for org_feature, gen_feature, style_gram in zip(
        original_features, generated_features, style_grams
    ):
        # Content term: mean squared difference between activations.
        original_loss += torch.mean((gen_feature - org_feature) ** 2)
        # Style term: mean squared difference between Gram matrices.
        style_loss += torch.mean((_gram_matrix(gen_feature) - style_gram) ** 2)

    total_loss = alpha * original_loss + beta * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if step % 100 == 0:
        # .item() prints the plain number rather than the tensor repr with
        # its grad_fn.
        print(total_loss.item())
        save_image(generated, "Generated.png")

tensor(82596.7891, grad_fn=<AddBackward0>)
tensor(10417.4268, grad_fn=<AddBackward0>)
tensor(6534.3311, grad_fn=<AddBackward0>)
