In [1]:
import torch
import warnings
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
from zipfile import ZipFile
from PIL import Image
import yaml
from pathlib import Path


with open("cfg.yaml", "r") as file:
        cfg = yaml.safe_load(file)

sys.path.append("src/")


torch.manual_seed(1)
warnings.filterwarnings("ignore")

print(f"cuda is available: {torch.cuda.is_available()}")
print(f"mps is available: {torch.backends.mps.is_available()}")

from net import VGG_Encoder, Decoder, Net
from train import train
from inference_generation import test_transform, style_transfer


cuda is available: False
mps is available: True


In [2]:
encoder = VGG_Encoder()
decoder = Decoder()

In [3]:
print(encoder)
print(decoder)

VGG_Encoder(
  (relu1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
  )
  (relu2): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
  )
  (relu3): Sequential(
    (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
  )
  (relu4): Sequential(
    (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(

In [4]:
net = Net(encoder, decoder)
print(net)

Net(
  (enc_1): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
    )
  )
  (enc_2): Sequential(
    (0): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): ReLU(inplace=True)
    )
  )
  (enc_3): Sequential(
    (0): Sequential(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): ReLU(inplace=True)
    )
  )
  (enc_4): Sequential(
    (0): Sequential(
      (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

In [5]:
grayscale_image=torch.rand(1,3,256,256)
style_reference_image=torch.rand(1,3,256,256)

output = net(grayscale_image, style_reference_image)
print(output[0].shape)

torch.Size([1, 3, 256, 256])


In [6]:
train(net,cfg)

Using device: mps

 60%|██████    | 603/1000 [12:05<08:21,  1.26s/it]

Test

In [None]:
# This are the image and the style I want to mix
# --> Set their path in cfg.yaml
input_img=Image.open("data/content_dir/brad_pitt.jpg")
display(input_img)
style_ref=Image.open("data/style_dir/brushstrokes.jpg")
display(style_ref)

device = torch.device("cuda" if torch.cuda.is_available()
                else "mps" if torch.backends.mps.is_built() else "cpu")

output_dir = Path(cfg["output_dir"])
output_dir.mkdir(exist_ok=True, parents=True)

# Either --content or --contentDir should be given.
assert (cfg["content"] or cfg["content_dir"])
if cfg["content"]:
    content_paths = [Path(cfg["content"])]
else:
    content_dir = Path(cfg["content_dir"])
    content_paths = [f for f in content_dir.glob('*')]

# Either --style or --styleDir should be given.
assert (cfg["style"] or cfg["style_dir"])
if cfg["style"]:
    style_paths = cfg["style"].split(',')
    if len(style_paths) == 1:
        style_paths = [Path(cfg["style"])]
    else:
        do_interpolation = True
        assert (cfg["style_interpolation_weights"] != ''), \
            'Please specify interpolation weights'
        weights = [int(i) for i in cfg["style_interpolation_weights"].split(',')]
        interpolation_weights = [w / sum(weights) for w in weights]
else:
    style_dir = Path(cfg["style_dir"])
    style_paths = [f for f in style_dir.glob('*')]

encoder = VGG_Encoder()
decoder = Decoder()

decoder.eval()
encoder.eval()

decoder.load_state_dict(torch.load(cfg["decoder"]))
# encoder.load_state_dict(torch.load(args.vgg))
#vgg = nn.Sequential(*list(vgg.children())[:31])

encoder.to(device)
decoder.to(device)

content_tf = test_transform(cfg["content_size"], cfg["crop"])
style_tf = test_transform(cfg["style_size"], cfg["crop"])

do_interpolation = False

for content_path in content_paths:
    # OFF
    if do_interpolation:  # one content image, N style image
        style = torch.stack([style_tf(Image.open(str(p))) for p in style_paths])
        content = content_tf(Image.open(str(content_path))) \
            .unsqueeze(0).expand_as(style)
        style = style.to(device)
        content = content.to(device)
        with torch.no_grad():
            output = style_transfer(encoder, decoder, content, style,
                                    cfg["alpha"], interpolation_weights)
        output = output.cpu()
        output_name = output_dir / '{:s}_interpolation{:s}'.format(
            content_path.stem, cfg["save_ext"])
        torchvision.utils.save_image(output, str(output_name))

    # ON
    else:  # process one content and one style
        for style_path in style_paths:
            content = content_tf(Image.open(str(content_path)))
            style = style_tf(Image.open(str(style_path)))
            #if cfg["preserve_color"]:
            #    style = coral(style, content)
            style = style.to(device).unsqueeze(0)
            content = content.to(device).unsqueeze(0)
            with torch.no_grad():
                output = style_transfer(encoder, decoder, content, style,
                                        cfg["alpha"])
            output = output.cpu()
            display(output)

            output_name = output_dir / '{:s}_stylized_{:s}{:s}'.format(
                content_path.stem, style_path.stem, cfg["save_ext"])
            torchvision.utils.save_image(output, str(output_name))
            

tensor([[[[ 0.0432,  0.0433,  0.0431,  ...,  0.0431,  0.0431,  0.0432],
          [ 0.0434,  0.0435,  0.0433,  ...,  0.0432,  0.0432,  0.0432],
          [ 0.0432,  0.0433,  0.0432,  ...,  0.0432,  0.0432,  0.0433],
          ...,
          [ 0.0428,  0.0427,  0.0426,  ...,  0.0434,  0.0432,  0.0438],
          [ 0.0428,  0.0426,  0.0425,  ...,  0.0437,  0.0434,  0.0440],
          [ 0.0427,  0.0426,  0.0425,  ...,  0.0435,  0.0430,  0.0437]],

         [[ 0.0286,  0.0285,  0.0286,  ...,  0.0283,  0.0285,  0.0284],
          [ 0.0287,  0.0286,  0.0287,  ...,  0.0284,  0.0285,  0.0285],
          [ 0.0285,  0.0284,  0.0286,  ...,  0.0284,  0.0285,  0.0285],
          ...,
          [ 0.0286,  0.0284,  0.0283,  ...,  0.0268,  0.0269,  0.0270],
          [ 0.0285,  0.0284,  0.0283,  ...,  0.0272,  0.0272,  0.0274],
          [ 0.0285,  0.0284,  0.0284,  ...,  0.0272,  0.0274,  0.0273]],

         [[-0.0103, -0.0103, -0.0103,  ..., -0.0099, -0.0099, -0.0099],
          [-0.0100, -0.0101, -