In [None]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data 
from PIL import Image
import torchvision.transforms as transformers
from torchvision.utils import save_image
import torch.optim as optim
import torchvision.models as models


In [None]:
# Pick the compute device: prefer CUDA, then Apple's Metal backend (MPS), and
# fall back to the CPU. Previously a CUDA machine silently ran on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
device

In [None]:
# Show the layer layout of the pretrained VGG19 convolutional trunk, so the
# string indices selected by the VGG class below can be matched to modules.
vgg19_full = models.vgg19(weights='VGG19_Weights.DEFAULT')
model = vgg19_full.features
del vgg19_full
print(model)

In [None]:

class VGG(nn.Module):
    """Frozen, pretrained VGG19 trunk that returns activations from selected layers.

    Args:
        layers: indices into ``models.vgg19(...).features`` (ints or strings)
            whose outputs ``forward`` collects. Defaults to the five indices
            this notebook's style-transfer loss has always used.

    Raises:
        ValueError: if ``layers`` is empty.
    """

    DEFAULT_LAYERS = ('0', '5', '10', '19', '28')

    def __init__(self, layers=DEFAULT_LAYERS):
        super().__init__()
        # Stored as strings to stay compatible with the original attribute.
        self.layers = [str(layer) for layer in layers]
        if not self.layers:
            raise ValueError("VGG needs at least one layer index to extract")
        # Keep only the trunk up to the deepest requested layer; anything after
        # it would be computed and thrown away. With the default layers this is
        # features[:29], as before.
        deepest = max(int(layer) for layer in self.layers)
        self.model = models.vgg19(weights='VGG19_Weights.DEFAULT').features[: deepest + 1]
        # The pretrained weights act as a fixed feature extractor: the optimizer
        # only updates the generated image, so skip computing weight gradients
        # (they would otherwise accumulate every step, because the optimizer's
        # zero_grad() never touches these parameters). eval() pins inference
        # behaviour for any mode-dependent layers.
        self.model.requires_grad_(False)
        self.eval()

    def forward(self, x):
        """Return the activations at ``self.layers``, ordered by depth.

        Args:
            x: input batch fed to the first layer of ``self.model``.

        Returns:
            List of tensors, one per selected index present in ``self.model``.
        """
        # NOTE(review): torchvision's VGG trunk uses nn.ReLU(inplace=True), so an
        # activation stored here may be overwritten by the ReLU that follows it
        # (i.e. effectively collected post-ReLU) — confirm this is intended.
        wanted = set(self.layers)
        features = []
        for index, layer in enumerate(self.model):
            x = layer(x)
            if str(index) in wanted:
                features.append(x)
        return features
                        



In [None]:
IMAGE_SIZE = (512, 512)  # (height, width) every input image is resized to

loader = transformers.Compose(
    [
        transformers.Resize(IMAGE_SIZE),
        transformers.ToTensor(),
    ]
)


def load_image(image_path):
    """Load an image file as a ``(1, 3, *IMAGE_SIZE)`` float tensor on ``device``.

    Args:
        image_path: path of the image file to read.

    Returns:
        Tensor with a leading batch dimension of 1, values produced by
        ``ToTensor`` (floats in [0, 1]), moved to ``device``.
    """
    # The context manager closes the underlying file handle.
    with Image.open(image_path) as raw:
        # Force 3 channels: RGBA, grayscale or palette images would otherwise
        # yield 1 or 4 channels and break the 3-channel VGG input.
        rgb = raw.convert("RGB")
    return loader(rgb).unsqueeze(0).to(device)

In [None]:
# Both images pass through the same fixed-size loader, so they share a spatial
# size. Content is loaded first, then style.
content_image, style_image = (load_image(path) for path in ("content.jpg", "style.jpg"))

In [None]:
# Inspect the loaded tensor's dimensions (batch, channels, height, width).
content_image.size()

In [None]:
# Build the multi-layer feature extractor and move its weights to the chosen device.
model = VGG()
model = model.to(device)

In [None]:
# Start from a copy of the content image; this copy is the tensor the optimizer updates.
generated_image = content_image.clone()
generated_image.requires_grad_(True)

In [None]:
# Optimisation hyper-parameters.
TOTAL_STEPS = 400     # number of optimizer updates applied to the generated image
LEARNING_RATE = 1e-3  # Adam step size
ALPHA = 1             # weight of the content term in the total loss
BETA = 1e-2           # weight of the style term in the total loss


In [None]:
# Adam optimises the pixels of the generated image only.
optimizer = optim.Adam(params=[generated_image], lr=LEARNING_RATE)

In [None]:
def gram_matrix(feature):
    """Channel-by-channel inner products of a feature map (the style representation).

    Args:
        feature: activation tensor of shape ``(1, C, H, W)``; the batch dimension
            must be 1 (``load_image`` adds a single batch dimension).

    Returns:
        ``(C, C)`` tensor ``F @ F.T``, where ``F`` is the map flattened to ``(C, H*W)``.
    """
    _, channels, height, width = feature.shape
    flat = feature.reshape(channels, height * width)
    return flat @ flat.t()


# The content and style targets never change, so extract them once, outside the
# autograd graph, instead of re-running VGG on both images every step.
with torch.no_grad():
    content_targets = model(content_image)
    style_targets = [gram_matrix(feature) for feature in model(style_image)]

for step in range(TOTAL_STEPS):
    generated_features = model(generated_image)

    content_loss = 0
    style_loss = 0
    for generated_feature, content_target, style_gram in zip(
        generated_features, content_targets, style_targets
    ):
        # Content loss: match the activations directly.
        content_loss += torch.mean((generated_feature - content_target) ** 2)
        # Style loss: match feature correlations (Gram matrices).
        style_loss += torch.mean((gram_matrix(generated_feature) - style_gram) ** 2)

    total_loss = ALPHA * content_loss + BETA * style_loss

    optimizer.zero_grad()
    # A fresh graph is built every step, so there is nothing to retain.
    total_loss.backward()
    optimizer.step()

    if step % 50 == 0:
        print(f"Step : {step} Total_Loss : {total_loss.item()} Content_Loss = {content_loss.item()} Style_Loss : {style_loss.item()}")
    # Also save on the final step; otherwise the last updates are never written out.
    if step % 200 == 0 or step == TOTAL_STEPS - 1:
        save_image(generated_image.detach(), "generate_image.jpg")
        
        