In [1]:
from torchvision.models import vgg19

In [2]:
# Build torchvision's VGG19, asking the builder for its pretrained weights.
# NOTE(review): recent torchvision releases deprecate `pretrained=True` in favour
# of the `weights=` argument — check the installed version before switching.
vgg19_model = vgg19(pretrained=True)



In [3]:
# Display the first ten entries (indices 0-9) of vgg19_model.features.
vgg19_model.features[:10]

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)

In [4]:
# Lookup table: layer name -> position inside vgg19_model.features.
# Indices 0, 5 and 10 correspond to Conv2d entries in the recorded layer
# listings; the deeper indices (19, 21, 28) presumably follow the same
# layout — verify against the full `features` printout.
conv = dict(
    conv1_1=0,
    conv2_1=5,
    conv3_1=10,
    conv4_1=19,
    conv4_2=21,
    conv5_1=28,
)

In [5]:
# Display the feature stack up to and including the 'conv3_1' entry;
# the +1 makes the slice end inclusive of that index.
vgg19_model.features[:conv['conv3_1']+1]

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

In [6]:
# Extract a feature map directly from the VGG19 feature stack

import torch

# Dummy input laid out as (batch, channels, H, W)
x = torch.ones(1, 3, 256, 256)

# This is a shape probe only, so run the forward pass without autograd:
# no graph is recorded and no intermediate activations are kept for backward.
with torch.no_grad():
    y = vgg19_model.features[:conv['conv3_1']+1](x)

y.shape  # shape of the extracted feature map

torch.Size([1, 256, 64, 64])

# 모델 실험

In [7]:
from models import StyleTransfer

In [8]:
# Instantiate the project's StyleTransfer model with no constructor arguments.
test_model = StyleTransfer()

In [9]:
import torch

# Dummy input laid out as (batch, channels, H, W)
input_size = (1, 3, 256, 256)
x = torch.ones(input_size)

# Run the model in its 'style' mode on the dummy input.
y = test_model(x, 'style')

In [10]:
# Shape of the element at index 1 of the model's 'style' output.
# NOTE(review): y is presumably a sequence of feature maps — confirm in models.StyleTransfer.
y[1].shape

torch.Size([1, 128, 128, 128])

# Loss 실험

In [11]:
from loss import StyleLoss

In [12]:
import torch

# Compute the Gram matrix of an all-ones tensor of size (1, 3, 256, 256).
style_loss = StyleLoss()
feature_map = torch.ones((1, 3, 256, 256))
G = style_loss.gram_matrix(feature_map)

In [13]:
# Shape of the Gram matrix returned by style_loss.gram_matrix.
G.shape

torch.Size([1, 3, 3])

In [15]:
# Evaluate the style loss between an all-ones and an all-zeros tensor
# of the same size.
probe_size = (1, 3, 256, 256)
x = torch.ones(probe_size)
y = torch.zeros(probe_size)

style_loss(x, y)

tensor(0.1111)

# Data Pre/Post Processing 테스트

In [None]:
import torch
from PIL import Image
from train_final import pre_processing, post_processing

In [23]:
# Round-trip check: PIL image -> tensor -> PIL image.
# Jupyter only renders the last expression of a cell, so intermediate results
# are shown with explicit display() calls instead of bare expressions.
from IPython.display import display

image = Image.open('./content.jpg')
display(image)

# image -> tensor
image_tensor = pre_processing(image)
display(image_tensor.shape)

# tensor -> image
image_pil = post_processing(image_tensor)
image_pil.size

(512, 512)