In [21]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image

In [22]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
# Every tensor and the feature extractor below are moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [23]:
class vggFeatureClass(nn.Module):
  """Frozen, pretrained VGG16 front-end that returns intermediate activations.

  Only the leading slice of ``models.vgg16(...).features`` up to the deepest
  requested layer is kept. Its parameters have ``requires_grad`` switched off,
  so a backward pass through this module can only produce gradients for the
  input tensor.

  Args:
    layer_indices: Non-negative positions inside ``vgg16().features`` whose
      outputs are collected by ``forward``. The default reproduces the set
      that used to be hard-coded, ``(0, 5, 10, 17, 24)``.
      NOTE(review): in torchvision's VGG16 these look like the first
      convolution of each block (pre-ReLU) - confirm against the installed
      torchvision version.

  Raises:
    ValueError: If ``layer_indices`` is empty, contains a negative index, or
      points past the end of ``vgg16().features``.
  """

  def __init__(self, layer_indices=(0, 5, 10, 17, 24)):
    super(vggFeatureClass, self).__init__()

    self.layerIndices = frozenset(layer_indices)
    if not self.layerIndices:
      raise ValueError("layer_indices must contain at least one index")
    if min(self.layerIndices) < 0:
      raise ValueError("layer_indices must be non-negative")

    model = models.vgg16(pretrained=True)
    deepest = max(self.layerIndices)
    if deepest >= len(model.features):
      raise ValueError("layer_indices exceeds the number of VGG16 feature layers")

    # Layers after the deepest tapped index never contribute to the output,
    # so they are dropped; the default indices give the original [:25] slice.
    # `device` is the notebook-level torch.device selected earlier.
    self.modelLayers = model.features[:deepest + 1].to(device)

    # Freeze the weights: only the image being optimised should receive gradients.
    for param in self.modelLayers.parameters():
      param.requires_grad = False

  def forward(self, x):
    """Run ``x`` through the kept layers and collect the tapped outputs.

    Args:
      x: Input tensor fed to the first kept VGG16 layer.

    Returns:
      list: Outputs of the layers listed in ``layer_indices``, ordered from
      the shallowest to the deepest layer.
    """
    feature = []
    for ind, layer in enumerate(self.modelLayers):
      x = layer(x)

      if ind in self.layerIndices:
        feature.append(x)

    return feature

In [24]:
# Preprocessing shared by the content and the style image.
# Resize(334) is given a single int; per the torchvision docs this scales the
# shorter side to 334 px and keeps the aspect ratio, so the two images may
# still end up with different widths/heights.
# ToTensor() then turns the PIL image into a float tensor.
preprocess_steps = [transforms.Resize(334), transforms.ToTensor()]
transform = transforms.Compose(preprocess_steps)

In [25]:
def _load_image(path):
  """Load an image file as a preprocessed 3-channel tensor on ``device``.

  The file is opened in a ``with`` block so its handle is released as soon
  as the pixels have been read. The image is forced to RGB because PNGs in
  particular can be RGBA, palette or grayscale, which would otherwise yield
  a tensor whose channel count does not match what VGG16's first layer
  expects.

  Args:
    path: Path of the image file to read.

  Returns:
    The tensor produced by the notebook-level ``transform``, moved to the
    notebook-level ``device``.
  """
  with Image.open(path) as img:
    # convert() returns a fully loaded copy, so the tensor is built while
    # the underlying file is still open.
    return transform(img.convert('RGB')).to(device)


# Content image: the picture whose structure should be preserved.
image = _load_image('image.png')

# Style image: the picture whose texture statistics should be imitated.
style = _load_image('style.jpg')


In [30]:
# The image being optimised starts as an independent copy of the content
# image; it is the only tensor that tracks gradients.
generatedImg = image.detach().clone().to(device)
generatedImg.requires_grad_(True)

In [31]:
# --- optimisation schedule ---
no_times = 4000        # the loop runs range(no_times + 1) update steps
learning_rate = 1e-3   # step size handed to Adam

# --- loss weighting: loss = alpha * content term + beta * style term ---
alpha = 1
beta = 1e-3

In [32]:
# Frozen VGG16 feature extractor, placed on the selected device.
vggFeature = vggFeatureClass().to(device)

# Adam updates the pixels of the generated image directly; the network
# weights are not part of the optimised parameters.
optim = torch.optim.Adam(params=[generatedImg], lr=learning_rate)

In [33]:
# The content and style images never change, so their features (and the style
# Gram matrices) are computed once up front instead of on every iteration.
# no_grad() avoids building an autograd graph for these constant targets.
with torch.no_grad():
  imageFeature = vggFeature(image)
  styleFeature = vggFeature(style)

  # Flatten each feature map to (shape[0], -1); the inputs are fed without a
  # batch dimension, so shape[0] is the leading dimension of the feature map.
  contentTargets = [feat.reshape(feat.shape[0], -1) for feat in imageFeature]

  styleGrams = []
  for contentFlat, styleFeat in zip(contentTargets, styleFeature):
    styleFlat = styleFeat.reshape(contentFlat.shape[0], -1)
    styleGrams.append(styleFlat @ styleFlat.t())

for n in range(no_times+1):
  generatedImgFeature = vggFeature(generatedImg)

  loss = 0
  for contentFlat, gramS, genFeat in zip(contentTargets, styleGrams, generatedImgFeature):
    generatedFlat = genFeat.reshape(contentFlat.shape[0], -1)

    # Content term: mean squared difference of the raw activations.
    orgLoss = torch.mean((contentFlat - generatedFlat)**2)

    # Style term: mean squared difference of the Gram matrices.
    gramG = generatedFlat @ generatedFlat.t()
    styleLoss = torch.mean((gramS - gramG)**2)

    loss += alpha*orgLoss + beta*styleLoss

  # Clear gradients before backward so that anything left over from an
  # interrupted run of this cell cannot leak into the first update.
  optim.zero_grad()
  loss.backward()
  optim.step()
  avgLoss = loss.item()
  if n%50 == 0:
    print("In ",n, " Loss: ",avgLoss)
    if n%1000 == 0:
      # Build the PIL image only when it is actually written to disk.
      # Pixel values are clamped to [0, 1] for saving only; the optimised
      # tensor itself is left unconstrained.
      img_pil = transforms.ToPILImage()(generatedImg.detach().clamp(0, 1).cpu())
      img_pil.save(f"generatedImg{n}.png")


In  0  Loss:  146318.140625
In  50  Loss:  43227.2734375
In  100  Loss:  23550.025390625
In  150  Loss:  17028.58984375
In  200  Loss:  13578.1787109375
In  250  Loss:  11376.1884765625
In  300  Loss:  9829.9892578125
In  350  Loss:  8675.271484375
In  400  Loss:  7770.05810546875
In  450  Loss:  7034.27490234375
In  500  Loss:  6418.9990234375
In  550  Loss:  5891.8232421875
In  600  Loss:  5434.2236328125
In  650  Loss:  5032.1943359375
In  700  Loss:  4675.77490234375
In  750  Loss:  4357.7099609375
In  800  Loss:  4072.177734375
In  850  Loss:  3813.872314453125
In  900  Loss:  3579.263671875
In  950  Loss:  3365.90234375
In  1000  Loss:  3171.191650390625
In  1050  Loss:  2992.74169921875
In  1100  Loss:  2828.75
In  1150  Loss:  2677.664306640625
In  1200  Loss:  2538.179931640625
In  1250  Loss:  2409.03662109375
In  1300  Loss:  2288.927734375
In  1350  Loss:  2176.66650390625
In  1400  Loss:  2071.77392578125
In  1450  Loss:  1973.8218994140625
In  1500  Loss:  1882.3409423828