In [1]:
import torch
import os
import time
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
from style_decorator import StyleDecorator
import torchvision.transforms as transforms
from PIL import Image

In [2]:
class AvatarNet(nn.Module):
    def __init__(self, layers=[1, 6, 11, 20]):
        super(AvatarNet, self).__init__()
        self.encoder = Encoder(layers)
        self.decoder = Decoder(layers)
        self.adain = AdaIN()
        self.decorator = StyleDecorator()

    def forward(self, content, styles, style_strength=1.0, patch_size=3, patch_stride=1, interpolation_weights=None):
        if interpolation_weights is None:
            interpolation_weights = [1/len(styles)] * len(styles)
        content_feature = self.encoder(content)
        style_features = []
        for style in styles:
            style_features.append(self.encoder(style))
        transformed_feature = []
        for style_feature, interpolation_weight in zip(style_features, interpolation_weights):
            transformed_feature.append(self.decorator(content_feature[-1], style_feature[-1], style_strength, patch_size, patch_stride) * interpolation_weight)
        transformed_feature = sum(transformed_feature)
        style_features = [style_feature[:-1][::-1] for style_feature in style_features]
        stylized_image = self.decoder(transformed_feature, style_features, interpolation_weights)
        return stylized_image

class AdaIN(nn.Module):
    def __init__(self):
        super(AdaIN, self).__init__()
    
    def forward(self, content, style, style_strength=1.0, eps=1e-5):
        b, c, h, w = content.size()
        content_std, content_mean = torch.std_mean(content.view(b, c, -1), dim=2, keepdim=True)
        style_std, style_mean = torch.std_mean(style.view(b, c, -1), dim=2, keepdim=True)
        normalized_content = (content.view(b, c, -1) - content_mean)/(content_std+eps)
        stylized_content = (normalized_content * style_std) + style_mean
        output = (1-style_strength)*content + style_strength*stylized_content.view(b, c, h, w)
        return output
    
class Encoder(nn.Module):
    def __init__(self,  layers=[1, 6, 11, 20]):        
        super(Encoder, self).__init__()
        vgg = torchvision.models.vgg19(pretrained=True).features
        self.encoder = nn.ModuleList()
        temp_seq = nn.Sequential()
        for i in range(max(layers)+1):
            temp_seq.add_module(str(i), vgg[i])
            if i in layers:
                self.encoder.append(temp_seq)
                temp_seq = nn.Sequential()
        
    def forward(self, x):
        features = []
        for layer in self.encoder:
            x = layer(x)
            features.append(x)
        return features


class Decoder(nn.Module):
    def __init__(self, layers=[1, 6, 11, 20], transformers=[AdaIN(), AdaIN(), AdaIN(), None]):
        super(Decoder, self).__init__()
        vgg = torchvision.models.vgg19(pretrained=False).features
        self.transformers = transformers
        self.decoder = nn.ModuleList()
        temp_seq  = nn.Sequential()
        count = 0
        for i in range(max(layers)-1, -1, -1):
            if isinstance(vgg[i], nn.Conv2d):
                out_channels = vgg[i].in_channels
                in_channels = vgg[i].out_channels
                kernel_size = vgg[i].kernel_size
                temp_seq.add_module(str(count), nn.ReflectionPad2d(padding=(1,1,1,1)))
                count += 1
                temp_seq.add_module(str(count), nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size))
                count += 1
                temp_seq.add_module(str(count), nn.ReLU())
                count += 1
            elif isinstance(vgg[i], nn.MaxPool2d):
                temp_seq.add_module(str(count), nn.Upsample(scale_factor=2))
                count += 1
            if i in layers:
                self.decoder.append(temp_seq)
                temp_seq  = nn.Sequential()
        self.decoder.append(temp_seq[:-1])    
        
    def forward(self, x, styles, interpolation_weights=None):
        if interpolation_weights is None:
            interpolation_weights = [1/len(styles)] * len(styles)
        y = x
        for i, layer in enumerate(self.decoder):
            y = layer(y)
            if self.transformers[i]:
                transformed_feature = []
                for style, interpolation_weight in zip(styles, interpolation_weights):
                    transformed_feature.append(self.transformers[i](y, style[i]) * interpolation_weight)
                y = sum(transformed_feature)
        return y

In [3]:
def lastest_arverage_value(values, length=100):
    if len(values) < length:
        length = len(values)
    return sum(values[-length:])/length

class ImageFolder(torch.utils.data.Dataset):
    def __init__(self, root_path, imsize=None, cropsize=None, cencrop=False):
        super(ImageFolder, self).__init__()

        self.file_names = sorted(os.listdir(root_path))
        self.root_path = root_path
        self.transform = _transformer(imsize, cropsize, cencrop)

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, index):
        image = Image.open(os.path.join(self.root_path + self.file_names[index])).convert("RGB")
        return self.transform(image)

In [4]:
def imsave(tensor, path):
    denormalize = _normalizer(denormalize=True)
    if tensor.is_cuda:
        tensor = tensor.cpu()
    tensor = torchvision.utils.make_grid(tensor)
    torchvision.utils.save_image(denormalize(tensor).clamp_(0.0, 1.0), path)
    return None

In [5]:
class Decoder(nn.Module):
    def __init__(self, layers=[1, 6, 11, 20], transformers=[AdaIN(), AdaIN(), AdaIN(), None]):
        super(Decoder, self).__init__()
        vgg = torchvision.models.vgg19(pretrained=False).features
        self.transformers = transformers
        self.decoder = nn.ModuleList()
        temp_seq  = nn.Sequential()
        count = 0
        for i in range(max(layers)-1, -1, -1):
            if isinstance(vgg[i], nn.Conv2d):
                out_channels = vgg[i].in_channels
                in_channels = vgg[i].out_channels
                kernel_size = vgg[i].kernel_size
                temp_seq.add_module(str(count), nn.ReflectionPad2d(padding=(1,1,1,1)))
                count += 1
                temp_seq.add_module(str(count), nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size))
                count += 1
                temp_seq.add_module(str(count), nn.ReLU())
                count += 1
            elif isinstance(vgg[i], nn.MaxPool2d):
                temp_seq.add_module(str(count), nn.Upsample(scale_factor=2))
                count += 1
            if i in layers:
                self.decoder.append(temp_seq)
                temp_seq  = nn.Sequential()
        self.decoder.append(temp_seq[:-1])    
        
    def forward(self, x, styles, interpolation_weights=None):
        if interpolation_weights is None:
            interpolation_weights = [1/len(styles)] * len(styles)
        y = x
        for i, layer in enumerate(self.decoder):
            y = layer(y)
            if self.transformers[i]:
                transformed_feature = []
                for style, interpolation_weight in zip(styles, interpolation_weights):
                    transformed_feature.append(self.transformers[i](y, style[i]) * interpolation_weight)
                y = sum(transformed_feature)
        return y

In [6]:
def network_train(content_dir):
    # set device
    device = torch.device('cuda')
    # get network
    network = AvatarNet([1, 6, 11, 20]).to(device)
    # get data set
    data_set = ImageFolder(content_dir, 176, 176, True)
    # get loss calculator
    loss_network = Encoder([1, 6, 11, 20]).to(device)
    mse_loss = torch.nn.MSELoss(reduction='mean').to(device)
    loss_seq = {'total':[], 'image':[], 'feature':[], 'tv':[]}
    # get optimizer
    for param in network.encoder.parameters():
        param.requires_grad = False
    optimizer = torch.optim.Adam(network.decoder.parameters(), lr=1e-3)
    # training
    for iteration in range(20000):
        data_loader = torch.utils.data.DataLoader(data_set, batch_size=16, shuffle=True)
        input_image = next(iter(data_loader)).to(device)
        if (iteration + 1) % 100 == 0:
            print(iteration+1, 
                20000)
            torch.save({'iteration': iteration+1,
                'state_dict': network.state_dict()
               },
                './'+'model.pth')
    return network

In [7]:
def _normalizer(denormalize=False):
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]    
    if denormalize:
        MEAN = [-mean/std for mean, std in zip(MEAN, STD)]
        STD = [1/std for std in STD]
    return transforms.Normalize(mean=MEAN, std=STD)

In [8]:
def _transformer(imsize=None, cropsize=None, cencrop=False):
    normalize = _normalizer()
    transformer = []
    if imsize:
        transformer.append(transforms.Resize(imsize))
    if cropsize:
        if cencrop:
            transformer.append(transforms.CenterCrop(cropsize))
        else:
            transformer.append(transforms.RandomCrop(cropsize))

    transformer.append(transforms.ToTensor())
    transformer.append(normalize)
    return transforms.Compose(transformer)

In [9]:
network_train('./image/')

100 20000
200 20000


KeyboardInterrupt: 