In [7]:
import os

from torch.utils.data import DataLoader
from easydict import EasyDict as edict
from tensorboardX import SummaryWriter
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
from PIL import Image
from torch import nn
import numpy as np
import torch

from dataset import dataset
from dataset import sampler
from net.model import *

In [8]:
# ---- Configuration ----------------------------------------------------------
# All tunable settings for the run are collected on `args`.
args                    = edict({})
args.lr                 = 1e-4
args.device             = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Training budget in iterations: the training loop draws one content batch and
# one style batch per "epoch", so this is NOT a count of passes over the data.
args.epochs             = 160_000
# Inverse-time decay factor used by adjust_lr: lr_t = lr / (1 + lr_decay * t).
args.lr_decay           = 5e-5
args.save_dir           = 'assets/results'
args.n_threads          = 16      # DataLoader worker processes, per stream
args.batch_size         = 8
args.style_path         = 'assets/styles'
args.style_weight       = 10.0
args.seed               = 0       # RNG seed for torch / numpy (reproducibility)

# Pretrained encoder weights (vgg_normalised.pth); download source:
# https://drive.google.com/file/d/1EpkBA2K2eYILDSyPTt0fztz59UjAIpZU/view
args.encoder_path       = 'assets/vgg_normalised.pth'
args.contents_path      = 'assets/contents'
args.content_weight     = 1.0
args.save_ckpt_interval = 10_000  # save a decoder checkpoint every N iterations

# Seed before any model is built so weight initialisation and data sampling are
# repeatable across runs. NOTE: DataLoader worker processes and cuDNN kernels
# may still introduce some run-to-run nondeterminism.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# makedirs creates intermediate directories, so this also creates save_dir itself.
os.makedirs(f'{args.save_dir}/logs', exist_ok = True)
writer = SummaryWriter(log_dir = f'{args.save_dir}/logs')

In [9]:
# ---- Model ------------------------------------------------------------------
decoder = Decoder()
encoder = Encoder()

# Only the first `n_encoder_layers` child modules of the encoder are kept below.
# NOTE(review): in the reference AdaIN setup this cut ends at relu4_1 — confirm
# against Encoder's definition in net/model.py.
n_encoder_layers = 31

# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only host;
# the whole model is moved to args.device afterwards anyway.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
load_result = encoder.load_state_dict(torch.load(args.encoder_path, map_location = 'cpu'), strict = False)

# strict=False tolerates checkpoint entries the Encoder does not define, but it
# would also silently leave any Encoder parameter that the checkpoint lacks at its
# random initialisation. Fail loudly if that happens to a layer we actually keep.
kept_children = {name for name, _ in list(encoder.named_children())[:n_encoder_layers]}
missing = [k for k in load_result.missing_keys if k.split('.')[0] in kept_children]
if missing:
    raise RuntimeError(f'Encoder parameters not found in {args.encoder_path}: {missing}')

encoder = nn.Sequential(*list(encoder.children())[:n_encoder_layers])
model   = Net(encoder, decoder).to(args.device)
model.train()

# Prints "model setup complete." (Korean).
print('모델 구성 완.')

모델 구성 완.


In [10]:
def adjust_lr(optimizer, iter_cnt, base_lr=None, decay=None):
    """Apply inverse-time learning-rate decay to every param group of `optimizer`.

    lr = base_lr / (1 + decay * iter_cnt)

    Args:
        optimizer: an optimizer exposing `param_groups` (a list of dicts with an 'lr' key).
        iter_cnt:  current iteration number.
        base_lr:   initial learning rate; defaults to the notebook's `args.lr`.
        decay:     decay factor; defaults to the notebook's `args.lr_decay`.

    Returns:
        The learning rate that was written into every param group.
    """
    # Fall back to the global config only when no explicit value is given, so the
    # function can also be used (and tested) without the notebook's `args`.
    if base_lr is None:
        base_lr = args.lr
    if decay is None:
        decay = args.lr_decay

    lr = base_lr / (1.0 + decay * iter_cnt)
    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr

In [11]:
# ---- Data -------------------------------------------------------------------
# Content and style images go through the same training transform pipeline.
content_tf = dataset.train_transform()
style_tf   = dataset.train_transform()

content_dataset = dataset.StyleTransferDataset(args.contents_path, content_tf)
style_dataset   = dataset.StyleTransferDataset(args.style_path, style_tf)


def _endless_batches(ds):
    """Return an iterator over `ds` batches that never raises StopIteration.

    The InfiniteSamplerWrapper sampler keeps yielding indices, so the training
    loop can simply call next() on the result once per iteration.
    """
    loader = DataLoader(
        ds,
        batch_size  = args.batch_size,
        sampler     = sampler.InfiniteSamplerWrapper(ds),
        num_workers = args.n_threads,
    )
    return iter(loader)


content_iter = _endless_batches(content_dataset)
style_iter   = _endless_batches(style_dataset)

# Only the decoder is optimised; the encoder's parameters are not handed to Adam.
optimizer = torch.optim.Adam(model.decoder.parameters(), lr = args.lr)

In [12]:
# ---- Training ---------------------------------------------------------------
# Despite its name, each `epoch` is a single optimisation step on one content
# batch and one style batch drawn from the infinite iterators.
# NOTE(review): the saved outputs of this cell show the content loss at 0.000 at
# every checkpoint, which is suspicious for AdaIN-style training — check that the
# encoder weights actually loaded and how Net computes its content loss.
cpu = torch.device('cpu')
try:
    for epoch in tqdm(range(1, args.epochs + 1)):

        # Inverse-time decay of the learning rate, recomputed every step.
        adjust_lr(optimizer, iter_cnt = epoch)
        content_images = next(content_iter).to(args.device)
        style_images   = next(style_iter).to(args.device)

        # NOTE(review): Net.forward is assumed to return (content_loss, style_loss)
        # as scalar tensors — confirm in net/model.py.
        loss_c, loss_s = model(content_images, style_images)
        loss_c         = args.content_weight * loss_c
        loss_s         =   args.style_weight * loss_s
        loss           = loss_c + loss_s

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        writer.add_scalar('content loss', loss_c.item(), epoch)
        writer.add_scalar(  'style loss', loss_s.item(), epoch)

        if epoch % args.save_ckpt_interval == 0 or epoch == args.epochs:

            # tqdm.write prints above the progress bar instead of breaking it apart.
            tqdm.write(f'content loss : {loss_c.item():.3f}')
            tqdm.write(f'  style loss : {loss_s.item():.3f}')
            tqdm.write(f'        loss : {loss.item():.3f}')

            # Move tensors to CPU in place (keeping the state_dict object and its
            # metadata) so the checkpoint loads on machines without a GPU.
            state_dict = model.decoder.state_dict()
            for k in state_dict.keys():
                state_dict[k] = state_dict[k].to(cpu)

            torch.save(state_dict, f'{args.save_dir}/decoder_{epoch:06d}.pth')
finally:
    # Flush and close the TensorBoard event file even if training is interrupted.
    writer.close()

  6%|▋         | 10002/160000 [10:27<2:42:12, 15.41it/s]

content loss : 0.000
  style loss : 0.301
        loss : 0.301


 13%|█▎        | 20002/160000 [20:52<2:31:58, 15.35it/s]

content loss : 0.000
  style loss : 0.320
        loss : 0.320


 19%|█▉        | 30002/160000 [31:18<2:19:45, 15.50it/s]

content loss : 0.000
  style loss : 0.287
        loss : 0.287


 25%|██▌       | 40002/160000 [41:44<2:08:41, 15.54it/s]

content loss : 0.000
  style loss : 0.272
        loss : 0.272


 31%|███▏      | 50002/160000 [52:08<1:58:20, 15.49it/s]

content loss : 0.000
  style loss : 0.275
        loss : 0.275


 38%|███▊      | 60002/160000 [1:02:35<1:47:28, 15.51it/s]

content loss : 0.000
  style loss : 0.320
        loss : 0.320


 44%|████▍     | 70002/160000 [1:13:03<1:37:19, 15.41it/s]

content loss : 0.000
  style loss : 0.262
        loss : 0.262


 50%|█████     | 80002/160000 [1:23:31<1:26:05, 15.49it/s]

content loss : 0.000
  style loss : 0.286
        loss : 0.286


 56%|█████▋    | 90002/160000 [1:33:57<1:15:13, 15.51it/s]

content loss : 0.000
  style loss : 0.319
        loss : 0.319


 63%|██████▎   | 100002/160000 [1:44:23<1:04:27, 15.51it/s]

content loss : 0.000
  style loss : 0.287
        loss : 0.287


 69%|██████▉   | 110002/160000 [1:54:53<55:48, 14.93it/s]  

content loss : 0.000
  style loss : 0.290
        loss : 0.290


 75%|███████▌  | 120002/160000 [2:05:21<42:59, 15.51it/s]

content loss : 0.000
  style loss : 0.326
        loss : 0.326


 81%|████████▏ | 130002/160000 [2:15:49<32:21, 15.45it/s]

content loss : 0.000
  style loss : 0.297
        loss : 0.297


 88%|████████▊ | 140002/160000 [2:26:17<21:54, 15.22it/s]

content loss : 0.000
  style loss : 0.287
        loss : 0.287


 94%|█████████▍| 150002/160000 [2:36:45<10:46, 15.46it/s]

content loss : 0.000
  style loss : 0.282
        loss : 0.282


100%|██████████| 160000/160000 [2:47:14<00:00, 15.94it/s]

content loss : 0.000
  style loss : 0.294
        loss : 0.294



