In [122]:
import torch
import torch.nn as nn
import torch.optim as optim

from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image

import matplotlib.pyplot as plt
from tqdm import tqdm

In [123]:
model=models.vgg19(pretrained=True).features

In [124]:
print(model)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [125]:
class VGG(nn.Module):
    def __init__(self):
        super(VGG,self).__init__()
        
        torch.autograd.set_detect_anomaly(True)
        self.chosen_features=['0','5','10','19','28']
        self.model=models.vgg19(pretrained=True).features[:29]
        
    def forward(self,x):
        features=[]
        
        for idx,layer in enumerate(self.model):
            x=layer(x)
            
            if str(idx) in self.chosen_features:
                features.append(x)
            
        return features
    

In [126]:
model=VGG().to(device).eval() #eval to freeze weights

In [127]:
def load_image(img_name):
    image=Image.open(img_name)
    image=image.resize((img_size,img_size))
    image = transform(image).unsqueeze(0)
    return image.to(device)

In [128]:
device=torch.device("cuda" if torch.cuda.is_available else "cpu")
device

device(type='cuda')

In [135]:
img_size=300

transform = transforms.Compose([transforms.ToTensor()])
        #transforms.Normalize(mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225))])

loader=transform

In [130]:
original_img=load_image("saranga.png")
style_img=load_image("monalisa.png")

In [131]:
#generated=torch.randn(original_img.shape,device=device,requires_grd=True)
generated=original_img.clone().requires_grad_(True)

In [132]:
total_steps=6000
lr=0.001
alpha=1
beta=0.01
optimizer=optim.Adam([generated],lr=lr)

In [136]:
for step in tqdm(range(total_steps)):
        
        # Extract multiple(5) conv feature vectors
        target_features = model(generated)
        content_features = model(original_img)
        style_features = model(style_img)

        style_loss = 0
        content_loss = 0
        for f1, f2, f3 in zip(target_features, content_features, style_features):
            # Compute content loss with target and content images
            content_loss += torch.mean((f1 - f2)**2)

            # Reshape convolutional feature maps
            _, c, h, w = f1.size()
            f1 = f1.view(c, h * w)
            f3 = f3.view(c, h * w)

            # Compute gram matrix
            G = torch.mm(f1, f1.t())
            A = torch.mm(f3, f3.t())

            # Compute style loss with target and style images
            style_loss += torch.mean((G - A)**2) / (c * h * w) 
        
        # Compute total loss, backprop and optimize
        total_loss = alpha*content_loss + beta * style_loss 
        
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        
        if step%200==0:
            print(total_loss)
            save_image(generated,"generated.png")

  0%|                                                                               | 1/6000 [00:03<5:02:28,  3.03s/it]

tensor(9.0440, device='cuda:0', grad_fn=<AddBackward0>)


  0%|                                                                               | 3/6000 [00:12<6:42:02,  4.02s/it]


KeyboardInterrupt: 