In [47]:
%reset

Once deleted, variables cannot be recovered. Proceed (y/[n])? y


In [48]:
import torch
import torchvision.models as models
import keras
from torchvision.utils import save_image

In [49]:
# Fetch the two input pictures (the "base" photo and the "style reference"
# painting) and keep the paths that get_file reports for them.
base_image_path = keras.utils.get_file(
    fname="/content/paris.jpg",
    origin="https://i.imgur.com/F28w3Ac.jpg",
)
style_reference_image_path = keras.utils.get_file(
    fname="/content/starry_night.jpg",
    origin="https://i.imgur.com/9ooB60I.jpg",
)

In [50]:
# Display the `features` Sequential of torchvision's pretrained VGG19; the
# printed indices show where each Conv2d / ReLU / MaxPool2d layer sits.
models.vgg19(pretrained=True).features



Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [112]:
import torch.nn as nn

class VGG(nn.Module):
  """Truncated VGG19 feature stack that returns several intermediate activations.

  Only the first 29 layers (indices 0..28) of ``vgg19().features`` are kept.
  ``forward`` runs the input through them in order and collects the tensor
  produced at every index listed in ``chosen_features``.

  NOTE(review): the printed architecture shows ``ReLU(inplace=True)`` directly
  after the Conv2d layers at indices 0, 5 and 10. A tensor collected at such an
  index is the same object the following in-place ReLU then overwrites, so the
  returned tensors for those layers appear to hold post-ReLU values (the last
  kept layer, 28, has no successor). Confirm if pre-activation features are
  required.
  """

  def __init__(self):
    super().__init__()

    # Layer indices, stored as strings, whose outputs forward() collects.
    self.chosen_features = ["0", "5", "10", "19" ,"28"]
    # Everything after index 28 is never needed, so it is sliced away.
    self.model = models.vgg19(pretrained=True).features[:29]

  def forward(self, x):
    """Return the list of activations at the chosen layer indices, in layer order."""
    collected = []

    for index, layer in enumerate(self.model):
      # Feed the running activation through the layers sequentially.
      x = layer(x)
      if str(index) not in self.chosen_features:
        continue
      collected.append(x)

    return collected



In [113]:
from PIL import Image
from torchvision import transforms

# Side length, in pixels, of the square every image is resized to.
img_dim = 356

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Preprocessing pipeline: resize to img_dim x img_dim, then convert the PIL
# image into a tensor.
loader = transforms.Compose(
    [transforms.Resize(size=(img_dim, img_dim)), transforms.ToTensor()]
)

def load_image(img_path):
  """Load an image file as a single-image RGB batch on ``device``.

  Args:
    img_path: Path of the image file to open with PIL.

  Returns:
    The result of ``loader`` applied to the image, with a leading batch
    dimension of size 1 added, moved to ``device``.
  """
  # The context manager closes the file handle as soon as the pixels are read.
  with Image.open(img_path) as img:
    # Force three channels: grayscale, palette or RGBA files would otherwise
    # yield 1 or 4 channels, which the 3-channel first conv layer rejects.
    rgb_img = img.convert("RGB")
  return loader(rgb_img).unsqueeze(0).to(device)


In [114]:
# Load from the paths returned by the download step instead of repeating the
# literal file names, so the two cells cannot drift apart.
original_img = load_image(base_image_path)
style_img = load_image(style_reference_image_path)

In [115]:
# Start from a copy of the content image and mark it as trainable: this
# tensor is what gets optimised.
generated = original_img.clone()
generated.requires_grad = True

In [116]:
total_steps = 6_000  # number of optimisation iterations
lr = 1e-3            # learning rate handed to the optimiser
alpha = 1            # weight of the content loss in the total loss
beta = 1e-2          # weight of the style loss in the total loss

In [117]:
# Adam updates the pixels of `generated` directly; it is the only tensor
# handed to the optimiser.
optimizer = torch.optim.Adam(params=[generated], lr=lr)

In [122]:
# Build the feature extractor on the target device and freeze every weight.
# Kept as one assignment so the cell does not echo the module repr.
model = VGG().to(device).requires_grad_(False)



In [123]:
def _gram_matrix(feature):
  """Return the (channel x channel) Gram matrix of a feature map.

  The reshape to (channel, height*width) only succeeds when the batch
  dimension is 1, which is what load_image produces.
  """
  _, channel, height, width = feature.shape
  flat = feature.view(channel, height*width)
  return flat.mm(flat.t())


# The content and style images never change during optimisation, so their
# features and the style Gram matrices are computed once up front instead of
# on every step. That removes two of the three VGG forward passes per step.
with torch.no_grad():
  original_img_features = model(original_img)
  style_features = model(style_img)
  style_grams = [_gram_matrix(style_feature) for style_feature in style_features]

for step in range(total_steps):
  # Only the image being optimised needs a fresh forward pass.
  generated_features = model(generated)

  style_loss=0
  content_loss = 0

  for gen_feature, orig_feature, style_gram in zip(generated_features, original_img_features, style_grams):
    # Content term: mean squared difference between feature maps.
    content_loss += torch.mean((orig_feature - gen_feature)**2)

    # Style term: mean squared difference between Gram matrices.
    gen_gram = _gram_matrix(gen_feature)
    style_loss += torch.mean((style_gram - gen_gram) ** 2)

  total_loss = (beta*style_loss) + (alpha*content_loss)
  optimizer.zero_grad()
  total_loss.backward()
  optimizer.step()

  # Report progress and save a snapshot every 200 steps.
  if step%200==0:
    print(f"Loss at step-{step} = {total_loss.item()}")
    save_image(generated, f"/content/save-{step}.jpg")


Loss at step-0 = 667964.125
Loss at step-200 = 30273.453125
Loss at step-400 = 12338.76953125
Loss at step-600 = 5605.4306640625
Loss at step-800 = 3428.556640625
Loss at step-1000 = 2662.79150390625
Loss at step-1200 = 2266.47802734375
Loss at step-1400 = 1998.0616455078125
Loss at step-1600 = 1793.9908447265625
Loss at step-1800 = 1631.442138671875
Loss at step-2000 = 1495.5675048828125
Loss at step-2200 = 1379.499755859375
Loss at step-2400 = 1277.97265625
Loss at step-2600 = 1188.7489013671875
Loss at step-2800 = 1110.1407470703125
Loss at step-3000 = 1040.2374267578125
Loss at step-3200 = 977.2183227539062
Loss at step-3400 = 920.3029174804688
Loss at step-3600 = 869.2216186523438
Loss at step-3800 = 823.2390747070312
Loss at step-4000 = 781.6322021484375
Loss at step-4200 = 744.1322631835938
Loss at step-4400 = 709.6589965820312
Loss at step-4600 = 679.482177734375
Loss at step-4800 = 651.6325073242188
Loss at step-5000 = 627.5870971679688
Loss at step-5200 = 605.4299926757812
Lo

In [125]:
# Save the final result. The step count in the file name comes from
# total_steps, so it stays in line with the periodic snapshots if the
# number of steps is changed.
save_image(generated, f"/content/save-{total_steps}.jpg")
