In [1]:
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.utils.data as data
from PIL import Image, ImageFile
from tensorboardX import SummaryWriter
from torchvision import transforms
from tqdm import tqdm
from sampler import InfiniteSamplerWrapper
from lib.depth_net import DepthV3
from pathlib import Path

import net

In [7]:
# --- Data locations and pretrained encoder weights ---
content_dir = "input/faces"
style_dir = "input/arts"
vgg_path = 'models/vgg_normalised.pth'

# --- Output locations ---
save_dir = './experiments'  # decoder checkpoints are written here
log_dir = './logs'  # TensorBoard event files

# --- Optimisation schedule ---
lr = 1e-4  # base learning rate
lr_decay = 5e-5  # inverse-time decay factor: lr / (1 + lr_decay * iteration)
max_iter = 10_000  # NOTE(review): confirm this matches the training loop's iteration range
batch_size = 8

# --- Loss weights ---
style_weight = 10.0
content_weight = 1.0
depth_weight = 100.0

# --- Data loading / checkpointing ---
n_threads = 4  # DataLoader worker processes
save_model_interval = 1_000  # save a decoder checkpoint every N iterations

In [8]:
def train_transform():
    """Build the training-time preprocessing pipeline.

    Each PIL image is resized to 300x300, a random 256x256 patch is cropped
    from it, and the patch is converted to a tensor.

    Returns:
        A ``transforms.Compose`` applying those three steps in order.
    """
    return transforms.Compose([
        transforms.Resize(size=(300, 300)),
        transforms.RandomCrop(256),
        transforms.ToTensor(),
    ])

class FlatFolderDataset(data.Dataset):
    """Dataset of images stored directly (non-recursively) in a single folder.

    Args:
        root: directory scanned for image files.
        transform: callable applied to each RGB ``PIL.Image`` in
            ``__getitem__``; its return value is what the dataset yields.
        extensions: iterable of file suffixes (including the dot) to include,
            matched case-insensitively. Defaults to ``('.jpg',)``.

    Raises:
        FileNotFoundError: if ``root`` does not exist or contains no
            matching files.
    """

    def __init__(self, root, transform, extensions=('.jpg',)):
        super().__init__()
        self.root = root
        wanted = {ext.lower() for ext in extensions}
        # Case-insensitive suffix match so e.g. 'IMG_01.JPG' is not silently
        # skipped; sorting gives a deterministic index -> file mapping.
        self.paths = sorted(
            p for p in Path(self.root).iterdir()
            if p.is_file() and p.suffix.lower() in wanted
        )
        if not self.paths:
            # Fail here with a clear message rather than later inside the
            # sampler/DataLoader with an unrelated-looking error.
            raise FileNotFoundError(
                'No {} images found in {!r}'.format(
                    '/'.join(sorted(wanted)), str(root)))
        self.transform = transform

    def __getitem__(self, index):
        """Load image ``index`` as RGB and return ``transform(image)``."""
        path = self.paths[index]
        # The context manager closes the file handle; convert() returns a
        # loaded copy that stays valid after the file is closed.
        with Image.open(str(path)) as img:
            rgb = img.convert('RGB')
        return self.transform(rgb)

    def __len__(self):
        """Number of matching image files found under ``root``."""
        return len(self.paths)

    def name(self):
        return 'FlatFolderDataset'
    
def adjust_learning_rate(optimizer, iteration_count, lr, decay=None):
    """Apply inverse-time learning-rate decay, imitating the original implementation.

    Computes ``lr / (1 + decay * iteration_count)`` and writes it into every
    param group of ``optimizer``.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts with an
            ``'lr'`` key), e.g. a ``torch.optim`` optimizer.
        iteration_count: current iteration number.
        lr: the *base* learning rate. Pass the undecayed value every call;
            feeding back the previously returned rate compounds the decay.
        decay: decay factor; when ``None`` the notebook-level ``lr_decay``
            is used.

    Returns:
        The learning rate that was set.
    """
    if decay is None:
        decay = lr_decay  # notebook-level config value
    new_lr = lr / (1.0 + decay * iteration_count)
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr
    return new_lr

In [9]:
# Train on the GPU when one is available; otherwise fall back to the CPU so
# the notebook still executes (slowly) on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # Training inputs are fixed-size 256x256 crops, so letting cuDNN
    # benchmark convolution algorithms once is worthwhile.
    cudnn.benchmark = True

# Create checkpoint and log directories (safe to re-run: exist_ok=True).
save_dir = Path(save_dir)
save_dir.mkdir(exist_ok=True, parents=True)
log_dir = Path(log_dir)
log_dir.mkdir(exist_ok=True, parents=True)
writer = SummaryWriter(log_dir=str(log_dir))

In [10]:
# NOTE(review): the decoder is initialised from iteration 62000, while the
# training loop numbers its iterations from 80001 (which also drives the
# learning-rate decay) — confirm which checkpoint / resume point is intended.
decoder_ckpt = "experiments/decoder_iter_62000.pth"
depth_ckpt = 'models/depth_model_40_-1.076169682497328.pth'

# map_location='cpu' lets checkpoints saved from GPU tensors load on any
# machine; load_state_dict then copies the values into the module's tensors.
decoder = net.decoder
decoder.load_state_dict(torch.load(decoder_ckpt, map_location='cpu'))

# Pretrained VGG encoder, truncated to its first 31 child modules.
vgg = net.vgg
vgg.load_state_dict(torch.load(vgg_path, map_location='cpu'))
vgg = nn.Sequential(*list(vgg.children())[:31])

# Pretrained depth network with all parameters frozen (presumably used by
# net.Net to compute loss_d).
depth_net = DepthV3((100, 100))
depth_net.load_state_dict(torch.load(depth_ckpt, map_location='cpu'))
for param in depth_net.parameters():
    param.requires_grad = False

network = net.Net(vgg, decoder, depth_net)
network.train()
# train() recurses into every submodule, including the frozen depth network.
# requires_grad=False does not stop BatchNorm running-stat updates or Dropout,
# so switch it back to eval mode (a no-op if it has no such layers).
# NOTE(review): assumes net.Net keeps a reference to this depth_net instance.
depth_net.eval()
network.to(device);

In [11]:
content_tf = train_transform()
style_tf = train_transform()

content_dataset = FlatFolderDataset(content_dir, content_tf)
style_dataset = FlatFolderDataset(style_dir, style_tf)


def _make_batch_iter(dataset):
    """Return an iterator over batches of ``dataset``.

    Batches are drawn through ``InfiniteSamplerWrapper`` (presumably an
    endless sampler, so ``next()`` can be called indefinitely).
    """
    loader = data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=InfiniteSamplerWrapper(dataset),
        num_workers=n_threads,
    )
    return iter(loader)


content_iter = _make_batch_iter(content_dataset)
style_iter = _make_batch_iter(style_dataset)


In [None]:
optimizer = torch.optim.Adam(network.decoder.parameters(), lr=lr)

# Iteration numbering continues from an earlier run. It also feeds the
# learning-rate decay, so it is not purely cosmetic.
# NOTE(review): the decoder was loaded from decoder_iter_62000.pth — confirm
# that 80001 is the intended resume point.
start_iter, end_iter = 80001, 85001

try:
    for i in tqdm(range(start_iter, end_iter)):
        # Always decay from the *base* rate `lr`. Rebinding `lr` to the
        # returned value would compound the decay: at i ~ 80000 the rate
        # would be divided by ~5 again on every single step.
        current_lr = adjust_learning_rate(optimizer, iteration_count=i, lr=lr)

        content_images = next(content_iter).to(device)
        style_images = next(style_iter).to(device)

        loss_c, loss_s, loss_d, _, _ = network(content_images, style_images)
        loss_c = content_weight * loss_c
        loss_s = style_weight * loss_s
        loss_d = depth_weight * loss_d
        loss = loss_c + loss_s + loss_d

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step = i + 1
        writer.add_scalar('loss_content', loss_c.item(), step)
        writer.add_scalar('loss_style', loss_s.item(), step)
        writer.add_scalar('loss_depth', loss_d.item(), step)
        writer.add_scalar('lr', current_lr, step)

        # Checkpoint periodically and always after the last iteration.
        if step % save_model_interval == 0 or step == end_iter:
            # Save the decoder the optimizer is actually updating.
            state_dict = network.decoder.state_dict()
            # In-place value replacement keeps the state_dict's metadata.
            for key, value in state_dict.items():
                state_dict[key] = value.cpu()
            torch.save(state_dict, save_dir /
                       'decoder_iter_{:d}.pth'.format(step))
finally:
    # Flush/close the TensorBoard writer even if training is interrupted.
    writer.close()

  6%|▌         | 283/5000 [05:22<1:29:28,  1.14s/it]