In [1]:
##Arbitrary Style Transfer Code

Importing all the modules

In [2]:
import numpy as np
from os import listdir, mkdir, sep
from os.path import join, exists, splitext
import cv2

In [3]:
# from google.colab import drive
# drive.mount('/content/drive')

# New Section

In [4]:
# !wget http://images.cocodataset.org/zips/test2014.zip
# !unzip *.zip
# !rm *.zip

In [5]:
import torch
import torch.nn as nn
from torchvision import models, transforms
import matplotlib.pyplot as plt
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

In [6]:
model = torch.hub.load('pytorch/vision:v0.6.0', 'vgg19', pretrained=True)

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.6.0


In [7]:
!mkdir -p test2014/1
!mv test2014/*.jpg test2014/1/

mv: cannot stat 'test2014/*.jpg': No such file or directory


# New Section

In [8]:
enc=nn.Sequential(*list(model.features.children())[:21])
for i in enc.parameters():
  i.requires_grad = False
enc

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [9]:
# Use Sequential to define decoder [Just reverse of vgg with pooling replaced by nearest neigbour upscaling]
dec = nn.Sequential(
nn.Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect' ),
nn.ReLU(),
nn.Upsample(scale_factor=2,mode='nearest'),
nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
nn.ReLU(),
nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
nn.ReLU(),
nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
nn.ReLU(),
nn.Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
nn.ReLU(),
nn.Upsample(scale_factor=2,mode='nearest'),
nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
nn.ReLU(),
nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
nn.ReLU(),
nn.Upsample(scale_factor=2,mode='nearest'),
nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
nn.ReLU(),
nn.Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
nn.ReLU()
)

In [10]:
dec

Sequential(
  (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
  (1): ReLU()
  (2): Upsample(scale_factor=2.0, mode=nearest)
  (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
  (4): ReLU()
  (5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
  (6): ReLU()
  (7): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
  (8): ReLU()
  (9): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
  (10): ReLU()
  (11): Upsample(scale_factor=2.0, mode=nearest)
  (12): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
  (13): ReLU()
  (14): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
  (15): ReLU()
  (16): Upsample(scale_factor=2.0, mode=nearest)
  (17): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), pad

In [11]:
# pipeline

In [12]:
class Encoder_Decoder(nn.Module):
  def __init__(self, encoder, decoder):
    super().__init__()

    self.encoder = encoder
    self.decoder = decoder
    self.style_features = []
    self.style_layers = [1, 6, 11, 20] # relu1_1, relu2_1, relu3_1, relu4_1
    for i in self.style_layers:
      self.encoder._modules[str(i)].register_forward_hook(self.style_feature_hook)

  def style_feature_hook(self, module, input, output):
    self.style_features.append(output)

  def forward(self, image):

    self.content_in = self.encoder(image)
    self.style_features = []

    return self.decoder(self.content_in)

In [13]:
def read_image(path, size=None, gray=False):
    img = cv2.imread(path)
    if gray:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    else:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    
    if size != None:
        img = cv2.resize(img, size)
    return img

In [14]:
from PIL import Image
# lion = Image.open('drive/MyDrive/images/style/1/lion.jpg')
# lion = Image.open('https://external-content.duckduckgo.com/iu/?u=https%3A%2F%2Fs-media-cache-ak0.pinimg.com%2Foriginals%2Fad%2F21%2F86%2Fad2186c3301997a31780d5ab600d71c9.jpg&f=1&nofb=1')

In [15]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.Lambda(lambda x: x.to(device))
])

In [16]:
enc = enc.to(device)
dec = dec.to(device)

In [17]:
# train = DataLoader(ImageFolder('drive/MyDrive/images/content', transform=preprocess), batch_size=1, shuffle=True, num_workers=0)
# test = DataLoader(ImageFolder('drive/MyDrive/images/style', transform=preprocess), batch_size=1, shuffle=True, num_workers=0)

In [18]:
# x = 0
# for i in train:
#   save_image(i[0], f"{x}.jpg")
#   plt.imshow(Image.open(f"{x}.jpg"))
#   plt.show()

In [19]:
# x = 0
# for i in test:
#   save_image(i[0], f"{x}.jpg")
#   plt.imshow(Image.open(f"{x}.jpg"))
#   plt.show()

In [20]:
def trainModel(Model, Loss, trainImagesPath, testImagesPath, batchSize=1, epochs=5, learningRate=6e-4, weightDecay=1e-3, alpha=0.1):

  train_images = DataLoader(ImageFolder(trainImagesPath, transform=preprocess), batch_size=batchSize, shuffle=True, num_workers=0)
  # test_images = DataLoader(ImageFolder(testImagesPath, transform=preprocess), batch_size=batchSize, shuffle=True, num_workers=0)

  model = Model(enc, dec).to(device)
  loss_fn = Loss().to(device)

  optimizer = torch.optim.Adam(model.parameters(), lr=learningRate, weight_decay=weightDecay)

  for e in range(epochs):
    epoch_loss = 0.0
    i=0.0
    for img in train_images:
      img = img[0]
      i+=1
      if(i > 100): 
        break
      print("\r", i, end="")
      optimizer.zero_grad()

      decoded = model(img)
    
      content_in = model.encoder(img)
      content_out = model.encoder(decoded)
      styles_in = model.style_features[:4]
      styles_out = model.style_features[4:]

      loss = loss_fn(content_in, content_out, styles_in, styles_out)
      loss.backward()
      optimizer.step()
      epoch_loss += loss.item()
      del content_in, content_out, styles_in, styles_out, decoded, loss
      # save_image(decoded, f"tmp_{i}.jpg")
    print("\nEpoch", e, end=": ")
    print("Epoch Loss = ", epoch_loss/len(train_images))

  
  torch.save(model.state_dict(), 'enc_dec_model')
  return model



In [21]:
class ContentStyleLoss(nn.Module):
  def __init__(self, lam=0.5):
    super().__init__()
    self.lam = lam

  def forward (self, content_in, content_out, styles_in, styles_out):
    contentLoss = torch.norm(content_out - content_in)
    styleLoss = np.sum([
                           torch.linalg.norm(torch.mean(styles_out[i], (2, 3)) - torch.mean(styles_in[i], (2,3))) + 
                           torch.linalg.norm(torch.std(styles_out[i], axis=(2, 3), unbiased=False) - torch.std(styles_in[i], axis=(2, 3), unbiased=False)) 
                           for i in range(len(styles_in))
    ])

    return contentLoss + self.lam*styleLoss

In [22]:
md = trainModel(Encoder_Decoder, ContentStyleLoss, 'test2014', 'drive/MyDrive/images/style', epochs=100)

 100.0
Epoch 0: Epoch Loss =  7.406206860007281
 100.0
Epoch 1: Epoch Loss =  6.699786045394888
 100.0
Epoch 2: Epoch Loss =  6.268280286011266
 100.0
Epoch 3: Epoch Loss =  6.094792484959381
 100.0
Epoch 4: Epoch Loss =  6.0685850681138875
 100.0
Epoch 5: Epoch Loss =  6.011116146272418
 100.0
Epoch 6: Epoch Loss =  5.711828028002376
 100.0
Epoch 7: Epoch Loss =  5.739946423781422
 100.0
Epoch 8: Epoch Loss =  5.576249591352602
 100.0
Epoch 9: Epoch Loss =  5.470968141070471
 100.0
Epoch 10: Epoch Loss =  5.517062010820624
 100.0
Epoch 11: Epoch Loss =  5.477181994486703
 100.0
Epoch 12: Epoch Loss =  5.405483446337561
 100.0
Epoch 13: Epoch Loss =  5.533116719276901
 100.0
Epoch 14: Epoch Loss =  5.448136079882932
 100.0
Epoch 15: Epoch Loss =  5.429241176209955
 100.0
Epoch 16: Epoch Loss =  5.326064327387339
 100.0
Epoch 17: Epoch Loss =  5.196048619160982
 100.0
Epoch 18: Epoch Loss =  5.278502362670524
 100.0
Epoch 19: Epoch Loss =  5.077068604563439
 100.0
Epoch 20: Epoch Loss =

In [31]:
preprocess_test = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.to(device)),
    # transforms.Lambda(lambda x: x.unsqueeze(0))
])
lion = Image.open('drive/MyDrive/images/style/1/lion.jpg')
model_in = preprocess_test(lion)
model = Encoder_Decoder(enc, dec).to(device)
model.load_state_dict(torch.load('enc_dec_model'))

<All keys matched successfully>

In [25]:
op = model(model_in)
save_image(model_in, 'test_in.jpg')
save_image(op, 'test_out.jpg')

In [26]:
loss = ContentStyleLoss()

In [27]:
content_in = model.encoder(model_in)
content_out = model.encoder(op)
styles_in = model.style_features[:4]
styles_out = model.style_features[4:]
l = loss(content_in, content_out, styles_in, styles_out)

In [28]:
print(l.item())

1310.5318603515625


In [33]:
test_images = DataLoader(ImageFolder('drive/MyDrive/images/content', transform=preprocess_test), batch_size=1, shuffle=True, num_workers=0)
x = 1
for i in test_images:
  i = i[0]
  # print(i.shape)
  decoded = model(i)
  content_in = model.encoder(i)
  content_out = model.encoder(decoded)
  styles_in = model.style_features[:4]
  styles_out = model.style_features[4:]

  l = loss(content_in, content_out, styles_in, styles_out)

  save_image(i, f'in_{x}.jpg')
  save_image(decoded, f'out{x}.jpg')
  x += 1
  del content_in, content_out, styles_in, styles_out
