<a href="https://colab.research.google.com/github/Aryamaan777/Neural-Style-Transfer/blob/main/NeuralStyleTransfer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image

In [None]:
class VGG(nn.Module):
  """Frozen VGG19 feature extractor that returns activations at chosen layers.

  ``self.model`` holds the first 29 modules of torchvision's ImageNet-pretrained
  VGG19 ``features`` stack. ``forward`` runs the input through them in order and
  collects the output of every layer whose index (as a string) appears in
  ``self.chosen_features``.
  """

  def __init__(self):
    super(VGG, self).__init__()
    # Indices into vgg19().features. In torchvision's VGG19 layout these are the
    # first conv of each block (conv1_1, conv2_1, conv3_1, conv4_1, conv5_1).
    self.chosen_features = ["0", "5", "10", "19", "28"]
    # torchvision >= 0.13 replaced `pretrained=True` with the `weights` enum
    # (IMAGENET1K_V1 is the same checkpoint). Older releases keep the old keyword.
    weights_enum = getattr(models, "VGG19_Weights", None)
    if weights_enum is not None:
      vgg = models.vgg19(weights=weights_enum.IMAGENET1K_V1)
    else:
      vgg = models.vgg19(pretrained=True)
    # Layer 28 is the last one we read, so nothing after it is kept.
    self.model = vgg.features[:29]
    # The network is only a fixed feature extractor. Freeze its weights so
    # backward() through it neither computes nor accumulates gradients for them.
    for param in self.model.parameters():
      param.requires_grad_(False)

  def forward(self, x):
    """Return the activations of the chosen layers, in network order.

    Args:
      x: input batch fed to the first layer of ``self.model`` (presumably
        (N, 3, H, W) RGB in [0, 1] as produced by the notebook's loader —
        confirm).

    Returns:
      A list with one tensor per index in ``self.chosen_features`` that
      ``self.model`` reaches, ordered by layer index.
    """
    features = []
    for layer_num, layer in enumerate(self.model):
      x = layer(x)
      if str(layer_num) in self.chosen_features:
        features.append(x)
    return features

In [None]:
def load_image(image_name):
  """Load an image file as a batched tensor on ``device``.

  Reads ``image_name`` with PIL, converts it to 3-channel RGB, applies the
  notebook-level ``loader`` transform, adds a leading batch dimension of size 1,
  and moves the result to the notebook-level ``device``.

  Args:
    image_name: path to the image file.

  Returns:
    The transformed image tensor with a leading batch dimension of size 1.
  """
  # The context manager closes the file handle PIL keeps open for lazy decoding.
  with Image.open(image_name) as img:
    # Force RGB. Grayscale, paletted or RGBA files would otherwise give 1 or 4
    # channels and break VGG's 3-channel first convolution.
    image = loader(img.convert("RGB")).unsqueeze(0)
  return image.to(device)

In [None]:
# Run on the current CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Side length, in pixels, of the square that input images are resized to.
imsize = 356

In [None]:
# Preprocessing used by load_image: resize to an imsize x imsize square (the
# aspect ratio is not preserved), then convert the PIL image to a float tensor.
# ToTensor scales 8-bit pixel values to [0, 1].
loader = transforms.Compose([transforms.Resize((imsize, imsize)), transforms.ToTensor()])

In [None]:
# Content image (its activations drive the content loss) and style image (its
# Gram matrices drive the style loss), each returned batched and on `device`.
original_img, style_img = (load_image(path) for path in ("messi.jpg", "lightning.jpg"))

In [None]:
# Optimisation starts from a copy of the content image. This tensor is what the
# optimizer updates, so it must track gradients.
generated = torch.clone(original_img).requires_grad_()
# VGG feature extractor, moved to the chosen device and switched to eval mode.
model = VGG().to(device)
model.eval()

In [None]:
# Optimisation hyper-parameters.
total_steps = 6000    # number of Adam updates applied to `generated`
learning_rate = 1e-3  # Adam step size
alpha = 1             # weight of the content (original-image) loss
beta = 0.01           # weight of the style (Gram-matrix) loss
optimizer = optim.Adam(params=[generated], lr=learning_rate)

In [None]:
def _gram_matrix(feature):
  """Return the (channel x channel) Gram matrix of a single-image feature map.

  Assumes a batch size of 1 (as produced by ``load_image``). The map is
  flattened to (channel, height*width) and multiplied by its own transpose.
  """
  _, channel, height, width = feature.shape
  flat = feature.view(channel, height * width)
  return flat.mm(flat.t())


# Only `generated` is optimised. Make sure backward() does not also compute (and
# accumulate) gradients for the VGG weights.
model.requires_grad_(False)

# The content and style targets never change, so extract them once, outside the
# autograd graph. Otherwise VGG would re-run on both images every step.
with torch.no_grad():
  original_img_features = model(original_img)
  style_grams = [_gram_matrix(feature) for feature in model(style_img)]

for step in range(total_steps):
  generated_features = model(generated)

  style_loss = original_loss = 0
  for gen_feature, orig_feature, A in zip(generated_features, original_img_features, style_grams):
    # Content loss: mean squared difference of the raw activations.
    original_loss += torch.mean((gen_feature - orig_feature) ** 2)
    # Style loss: mean squared difference of the Gram matrices.
    G = _gram_matrix(gen_feature)
    style_loss += torch.mean((G - A) ** 2)

  total_loss = alpha * original_loss + beta * style_loss
  optimizer.zero_grad()
  total_loss.backward()
  optimizer.step()

  if step % 200 == 0:
    print(total_loss.item(), step)
    save_image(generated, "generated.jpg")

# Checkpoints above only fire every 200 steps. Without this final save, the last
# updates would never reach disk (e.g. steps 5801-5999 when total_steps is 6000).
save_image(generated, "generated.jpg")

tensor(706261.7500, device='cuda:0', grad_fn=<AddBackward0>) 0
tensor(224398.7344, device='cuda:0', grad_fn=<AddBackward0>) 200
tensor(20543.6914, device='cuda:0', grad_fn=<AddBackward0>) 400
tensor(8236.5371, device='cuda:0', grad_fn=<AddBackward0>) 600
tensor(6677.7939, device='cuda:0', grad_fn=<AddBackward0>) 800
tensor(5710.4585, device='cuda:0', grad_fn=<AddBackward0>) 1000
tensor(4992.0571, device='cuda:0', grad_fn=<AddBackward0>) 1200
tensor(4414.6987, device='cuda:0', grad_fn=<AddBackward0>) 1400
tensor(3932.2847, device='cuda:0', grad_fn=<AddBackward0>) 1600
tensor(3517.6711, device='cuda:0', grad_fn=<AddBackward0>) 1800
tensor(3155.9307, device='cuda:0', grad_fn=<AddBackward0>) 2000
tensor(2832.4741, device='cuda:0', grad_fn=<AddBackward0>) 2200
tensor(2540.7439, device='cuda:0', grad_fn=<AddBackward0>) 2400
tensor(2275.9248, device='cuda:0', grad_fn=<AddBackward0>) 2600
tensor(2035.5204, device='cuda:0', grad_fn=<AddBackward0>) 2800
tensor(1817.6431, device='cuda:0', grad_fn