In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import torchvision.transforms as transforms#img to tensor
import torchvision.models as models#vgg19
from torchvision.utils import save_image

In [None]:
# Load the pretrained VGG-19 and keep only its `.features` sub-module, then
# print it so the layer indices used for feature extraction can be looked up.
# Note: the name `model` is rebound later in this notebook to a VGG() wrapper.
model=models.vgg19(pretrained=True).features
print(model)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:06<00:00, 84.0MB/s]


Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [None]:
#needed layers=0,5,10,19,28
class VGG(nn.Module):
  """Frozen, truncated VGG-19 used as a fixed feature extractor.

  Only the first 29 layers of ``vgg19().features`` are kept (indices 0..28),
  because nothing after the last tapped layer is ever used. All weights are
  frozen: the network is never trained, it is only used to produce
  activations (and gradients with respect to its *input*).
  """

  def __init__(self):
    super(VGG, self).__init__()

    # Indices (as strings) into vgg19().features whose outputs forward()
    # collects. In the printed architecture 0, 5 and 10 are Conv2d layers;
    # 19 and 28 are presumably the first convs of the next two blocks.
    self.chosen_features = ['0', '5', '10', '19', '28']
    self.model = models.vgg19(pretrained=True).features[:29]

    # Freeze the pretrained weights. Without this every backward() call also
    # computes and accumulates gradients for all VGG parameters, which costs
    # time and memory even though no optimizer ever updates them.
    for param in self.model.parameters():
      param.requires_grad_(False)

  def forward(self, x):
    """Run ``x`` through the truncated network and collect tapped activations.

    Args:
      x: input tensor fed to the first layer of ``self.model``.

    Returns:
      A list with one tensor per entry of ``self.chosen_features``, in layer
      order.

    Note:
      The printed architecture shows ``ReLU(inplace=True)`` after the convs.
      A tensor stored here is the same object the following in-place ReLU
      overwrites, so for any tapped layer directly followed by such a ReLU the
      returned activation is effectively the post-ReLU value.
    """
    features = []
    for layer_num, layer in enumerate(self.model):
      x = layer(x)
      if str(layer_num) in self.chosen_features:
        features.append(x)
    return features


In [None]:
def load_image(image_name):
  """Load an image file and turn it into a batched tensor.

  Args:
    image_name: path (or file object) accepted by ``PIL.Image.open``.

  Returns:
    The result of the module-level ``loader`` transform applied to the image,
    with a leading batch dimension of size 1 added by ``unsqueeze(0)``.
  """
  # Force three colour channels: RGBA, grayscale or CMYK files would otherwise
  # produce a tensor whose channel count does not match VGG's first
  # Conv2d(3, 64). convert() returns an independent, fully loaded copy, so the
  # underlying file can be closed as soon as the `with` block ends.
  with Image.open(image_name) as source:
    image=source.convert("RGB")
  image=loader(image).unsqueeze(0)
  return image

# Device selection is currently disabled; nothing in this cell moves tensors
# to a device.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_size = 400  # both images are resized to image_size x image_size

# Preprocessing applied to every loaded image, in order.
preprocessing_steps = [
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    # ImageNet normalization is left disabled:
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
loader = transforms.Compose(preprocessing_steps)

# Content image first, style image second.
original_img, style_img = (
    load_image(path) for path in ("/content/1.jpeg", "/content/2.jpg")
)

In [None]:
# Three images are involved: the style image, the content image and a
# generated image whose pixels are optimized directly.
# generated=torch.randn(original_img.shape,requires_grad=True)
# Starting from a clone of the content image instead of random noise gives
# better results and is faster.
model=VGG().eval()
# The network is a fixed feature extractor: only `generated` is trained, so do
# not compute or accumulate gradients for the VGG weights.
model.requires_grad_(False)
generated=original_img.clone().requires_grad_(True)

#hp
total_steps=6000
lr=0.001
alpha=1     # weight of the content (original image) loss
beta=0.01   # weight of the style loss

# NOTE(review): AdamW applies decoupled weight decay (PyTorch default 0.01) to
# the pixels of `generated`, slowly pulling them toward 0 - confirm this is
# intended; otherwise pass weight_decay=0 or use optim.Adam.
optimizer=optim.AdamW([generated],lr=lr)


def gram_matrix(feature):
  """Return the (channel x channel) Gram matrix of one feature map.

  The feature map is flattened to (channel, height*width) and multiplied by
  its own transpose. The view() only works for a batch of size 1, which is
  what load_image() produces via unsqueeze(0).
  """
  batch_size,channel,height,width=feature.shape
  flat=feature.view(channel,height*width)
  return flat.mm(flat.t())


# The content and style targets do not depend on `generated`, so compute them
# once, without autograd, instead of on every optimization step.
with torch.no_grad():
  original_img_features=model(original_img)
  style_features=model(style_img)
  style_grams=[gram_matrix(feature) for feature in style_features]

for step in range(total_steps):
  # Only the generated image has to go through the network each step.
  generated_features=model(generated)

  style_loss=original_loss=0

  for gen_feature,orig_feature,style_gram in zip(
      generated_features,original_img_features,style_grams
  ):
      # Content loss: mean squared difference between activations.
      original_loss+=torch.mean((gen_feature-orig_feature)**2)

      # Style loss: mean squared difference between Gram matrices.
      style_loss+=torch.mean((gram_matrix(gen_feature)-style_gram)**2)

  total_loss=alpha*original_loss +beta*style_loss
  optimizer.zero_grad()
  total_loss.backward()
  optimizer.step()

  # Log the loss and save a snapshot of the generated image every 100 steps.
  if step % 100 == 0:
    print(f"Step {step}: Total Loss: {total_loss.item()}")
    save_image(generated, f"generated_{step}.jpeg")




Step 0: Total Loss: 1091594.75
Step 100: Total Loss: 190620.921875
Step 200: Total Loss: 94442.6484375
Step 300: Total Loss: 58888.6328125
Step 400: Total Loss: 41855.296875
Step 500: Total Loss: 31031.19140625
