In [1]:
import os

from torch.utils.data import DataLoader
from easydict import EasyDict as edict
from tensorboardX import SummaryWriter
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
from PIL import Image
from torch import nn
import numpy as np
import torch

from dataset import dataset
from dataset import sampler
from net.model import *

  from .autonotebook import tqdm as notebook_tqdm
  warn(f"Failed to load image Python extension: {e}")


In [2]:
# Every tunable of the run lives in one attribute-style dictionary so the
# remaining cells can read it as `args.<name>`.
args = edict({
    # -- optimisation ------------------------------------------------------
    'lr'                 : 1e-4,      # base learning rate
    'device'             : torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'epochs'             : 160_000,   # upper bound of the training loop counter
    'lr_decay'           : 5e-5,      # decay coefficient consumed by adjust_lr
    # -- outputs / data loading --------------------------------------------
    'save_dir'           : 'assets/results',
    'n_threads'          : 16,        # worker processes per DataLoader
    'batch_size'         : 8,
    'style_path'         : 'assets/styles',
    'style_weight'       : 10.0,      # multiplier applied to the style loss
    # -- encoder weights ---------------------------------------------------
    ## reference : https://drive.google.com/file/d/1EpkBA2K2eYILDSyPTt0fztz59UjAIpZU/view
    'encoder_path'       : 'misc/vgg_normalised.pth',
    'contents_path'      : 'assets/contents',
    'content_weight'     : 1.0,       # multiplier applied to the content loss
    'save_ckpt_interval' : 10_000,    # a checkpoint is written every this many steps
})

# Make sure the result directory and its log sub-directory exist before the
# TensorBoard writer is pointed at the latter.
os.makedirs(args.save_dir, exist_ok=True)
os.makedirs(f'{args.save_dir}/logs', exist_ok=True)
writer = SummaryWriter(log_dir=f'{args.save_dir}/logs')

In [3]:
# Instantiate both halves of the network and restore the encoder weights
# from the checkpoint file named in the config.
decoder = Decoder()
encoder = Encoder()

# map_location='cpu' makes the checkpoint readable on a machine without a GPU
# even if it was saved from a CUDA device; the assembled model is moved to
# args.device below, so the final placement is unchanged.
# NOTE(review): strict=False silently ignores missing/unexpected keys — confirm
# the checkpoint keys really match Encoder, otherwise it keeps its initial weights.
encoder.load_state_dict(torch.load(args.encoder_path, map_location = 'cpu'), strict = False)

# Only the first 31 child modules of the encoder are kept.
encoder = nn.Sequential(*list(encoder.children())[:31])
model   = Net(encoder, decoder).to(args.device)
model.train()

print('모델 구성 완.')

모델 구성 완.


In [4]:
def adjust_lr(optimizer, iter_cnt, base_lr = None, lr_decay = None):
    """Set an inverse-time decayed learning rate on every parameter group.

    The rate written to the optimizer is::

        lr = base_lr / (1 + lr_decay * iter_cnt)

    Args:
        optimizer: object exposing ``param_groups``, a list of dicts whose
            ``'lr'`` entry is overwritten in place.
        iter_cnt:  current iteration count used in the formula above.
        base_lr:   undecayed learning rate; falls back to ``args.lr`` when None.
        lr_decay:  decay coefficient; falls back to ``args.lr_decay`` when None.

    Returns:
        None. The optimizer is modified in place.
    """
    # The notebook-wide config is only consulted when the caller does not pass
    # explicit values, so the function can also be used without `args`.
    if base_lr  is None: base_lr  = args.lr
    if lr_decay is None: lr_decay = args.lr_decay

    lr = base_lr / (1.0 + lr_decay * iter_cnt)
    for group in optimizer.param_groups:
        group['lr'] = lr

In [5]:
def _batch_iterator(source):
    """Wrap `source` in a DataLoader driven by InfiniteSamplerWrapper and return its iterator."""
    loader = DataLoader(source,
                        batch_size  = args.batch_size,
                        sampler     = sampler.InfiniteSamplerWrapper(source),
                        num_workers = args.n_threads)
    return iter(loader)


# Content and style images each get their own transform instance.
content_tf = dataset.train_transform()
style_tf   = dataset.train_transform()

# One dataset per image folder named in the config.
content_dataset = dataset.StyleTransferDataset(args.contents_path, content_tf)
style_dataset   = dataset.StyleTransferDataset(args.style_path, style_tf)

# Batches are pulled from these iterators with next() in the training loop.
content_iter = _batch_iterator(content_dataset)
style_iter   = _batch_iterator(style_dataset)

# Only the decoder's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(model.decoder.parameters(), lr = args.lr)

In [6]:
# Training loop. `epoch` counts optimisation steps: each pass pulls exactly one
# content batch and one style batch and performs one optimizer update.
# The loop is wrapped in try/finally so the TensorBoard writer is closed even
# when the run is interrupted (e.g. KeyboardInterrupt) or raises midway.
try:
    for epoch in tqdm(range(1, args.epochs + 1)):

        # Refresh the learning rate for this step before the update.
        adjust_lr(optimizer, iter_cnt = epoch)
        content_images = next(content_iter).to(args.device)
        style_images   = next(style_iter).to(args.device)

        # The model returns the two loss terms; each is scaled by its
        # configured weight before being summed into the training objective.
        loss_c, loss_s = model(content_images, style_images)
        loss_c         = args.content_weight * loss_c
        loss_s         =   args.style_weight * loss_s
        loss           = loss_c + loss_s

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log the weighted losses for every step.
        writer.add_scalar('content loss', loss_c.item(), epoch)
        writer.add_scalar(  'style loss', loss_s.item(), epoch)

        # Report and checkpoint periodically, and always on the final step.
        if epoch % args.save_ckpt_interval == 0 or epoch == args.epochs:

            print(f'content loss : {loss_c.item():.3f}')
            print(f'  style loss : {loss_s.item():.3f}')
            print(f'        loss : {loss.item():.3f}')

            # Move every tensor to the CPU before saving the decoder weights.
            # Re-assigning existing keys does not resize the dict, so doing it
            # while iterating over the keys is safe.
            state_dict = model.decoder.state_dict()
            for k in state_dict.keys(): state_dict[k] = state_dict[k].cpu()

            # NOTE(review): 'decoer_' looks like a typo for 'decoder_'; it is kept
            # as-is because other code may load checkpoints by this exact name.
            torch.save(state_dict, f'{args.save_dir}/decoer_{str(epoch).zfill(6)}.pth.tar')
finally:
    writer.close()

  2%|▏         | 3295/160000 [03:28<2:44:53, 15.84it/s]


KeyboardInterrupt: 