In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image 

In [None]:
class VGG(nn.Module):
    """Frozen VGG19 feature extractor for neural style transfer.

    ``forward`` runs the input through ``vgg19().features[:29]`` (layers 0..28)
    and returns, in order, the activations produced by the layers whose index
    (as a string) is listed in ``chosen_features``. In torchvision's VGG19
    layout the default indices are the first convolution of each of the five
    blocks (conv1_1, conv2_1, conv3_1, conv4_1, conv5_1).
    """

    def __init__(self):
        super(VGG, self).__init__()
        # features[:29] keeps layers 0..28, i.e. just up to the deepest chosen
        # layer; nothing after it is ever needed.
        # NOTE: torchvision >= 0.13 deprecates ``pretrained=True`` in favour of
        # ``weights=models.VGG19_Weights.IMAGENET1K_V1``; kept for compatibility
        # with older torchvision releases.
        self.model = models.vgg19(pretrained=True).features[:29]
        self.chosen_features = ['0', '5', '10', '19', '28']
        # The network is only used as a fixed feature extractor. Without this,
        # every backward pass would also compute (and accumulate) gradients for
        # all VGG weights, which nothing ever uses or zeroes.
        for param in self.model.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Return the list of activations at the chosen layers, shallowest first."""
        features = []
        for layer_num, layer in enumerate(self.model):
            x = layer(x)
            if str(layer_num) in self.chosen_features:
                features.append(x)
        return features
    
def load_image(image_name):
    """Load an image file and return it as a batched tensor on ``device``.

    The image is converted to RGB first, so grayscale, RGBA or palette files
    still yield the three channels VGG's first convolution expects. It is then
    passed through the module-level ``loader`` transform and given a leading
    batch dimension of size 1.

    Args:
        image_name: Path of the image file to open.

    Returns:
        The transformed image with shape (1, C, H, W) as produced by ``loader``
        (presumably (1, 3, image_size, image_size) with the loader configured
        in this notebook), moved to ``device``.
    """
    # The context manager closes the underlying file handle; convert() forces
    # the pixel data to be read before the file is closed.
    with Image.open(image_name) as img:
        image = loader(img.convert("RGB")).unsqueeze(0)

    return image.to(device)

# torch.cuda.is_available must be *called*: the bare function object is always
# truthy, which would pick CUDA even on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

image_size = 356

# Resize every image to a fixed square and convert it to a float tensor in [0, 1].
# NOTE: no ImageNet Normalize step is applied (torchvision's pretrained VGG
# weights were trained on normalized input). Leaving it out lets save_image
# write the optimized tensor directly; adding it would require un-normalizing
# before saving.
loader = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ]
)

content_img = load_image("Salah.jpg")
style_img = load_image("Style1.jpg")

# eval() only switches the network to inference mode; it does not freeze
# anything. requires_grad_(False) is what stops autograd from tracking the
# VGG weights, which are never optimized.
model = VGG().to(device).eval()
model.requires_grad_(False)

# The optimization starts from a copy of the content image; this tensor is the
# only thing the optimizer updates.
generated_img = content_img.clone().requires_grad_(True)

# Hyperparameters
epochs = 6000
learning_rate = 0.001
alpha = 1  # weight of the content loss in the total loss
beta = 0.01  # weight of the style loss in the total loss

optimizer = optim.Adam([generated_img], lr=learning_rate)

def _gram_matrix(feature):
    """Return the (channels x channels) Gram matrix of a single-image feature map.

    Assumes a batch size of 1: the (1, C, H, W) map is flattened to (C, H*W)
    using its *own* spatial size and multiplied by its transpose.
    """
    _, channel, height, width = feature.shape
    flat = feature.reshape(channel, height * width)
    return flat.mm(flat.t())


# The content and style targets never change, so extract them once, without
# building an autograd graph, instead of re-running VGG on them every step.
with torch.no_grad():
    content_features = model(content_img)
    style_grams = [_gram_matrix(feature) for feature in model(style_img)]

for epoch in range(epochs):
    generated_features = model(generated_img)

    style_loss = 0
    content_loss = 0

    for G, C, gram_style in zip(generated_features, content_features, style_grams):
        # Content loss: mean squared difference of the raw feature maps.
        content_loss += torch.mean((G - C) ** 2)
        # Style loss: mean squared difference of the Gram matrices.
        gram_gen = _gram_matrix(G)
        style_loss += torch.mean((gram_gen - gram_style) ** 2)

    total_loss = alpha * content_loss + beta * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if epoch % 200 == 0:
        print(total_loss.item())
        # Integer division keeps filenames like "Generated_Image3.jpg"
        # (true division produced "Generated_Image3.0.jpg").
        save_image(generated_img.detach(), f"Generated_Image{epoch // 200}.jpg")