In [35]:
import cv2
import torch
import numpy as np


In [36]:
import torch.nn as nn
import torch.nn.functional as F
from utils.tensorboard import TensorBoard
from Renderer.model import FCN
from Renderer.stroke_gen import *

In [37]:
# Module-level training state shared by the cells below.
# TensorBoard is the project wrapper imported from utils.tensorboard; the
# argument is presumably the log directory -- confirm against that module.
writer = TensorBoard("../train_log/")
import torch.optim as optim

# Pixel-wise mean squared error between the network output and the target raster.
criterion = nn.MSELoss()
# FCN is the renderer network from Renderer.model; the training loop feeds it
# batches of 10-value stroke parameter vectors.
net = FCN()
# lr here is only the constructor value: the training loop reassigns
# param_group["lr"] on every iteration from its step-based schedule.
optimizer = optim.Adam(net.parameters(), lr=3e-6)
# Number of random strokes generated per training step.
batch_size = 64

# When True, the model and each batch are moved to the GPU by the training loop.
use_cuda = torch.cuda.is_available()
# Global step counter; drives the LR schedule, logging and checkpoint cadence.
step = 0

In [38]:
import random

def draw_rec(f, width=128):
    """Rasterize one stroke as a chain of filled squares along a quadratic Bezier curve.

    The curve is sampled at 100 values of t in [0, 0.99]. At each sample a
    filled, axis-aligned square is painted whose half-size and fill value are
    linearly interpolated between the two ends of the stroke. Painting happens
    on a (2 * width, 2 * width) canvas which is then resized to
    (width, width) and inverted.

    Args:
        f: sequence of 10 numbers unpacked as
            (x0, y0, x1, y1, x2, y2, z0, z2, w0, w2).
            (x0, y0) and (x2, y2) are the curve end points. (x1, y1) places
            the control point as a per-axis fraction of the way from the
            start point to the end point. z0 / z2 scale the square half-size
            at the two ends and w0 / w2 are the fill values at the two ends.
            The training loop in this file draws all ten uniformly from [0, 1).
        width: side length in pixels of the returned image.

    Returns:
        float32 array of shape (width, width) holding 1 minus the resized
        canvas, so unpainted pixels are 1.
    """
    x0, y0, x1, y1, x2, y2, z0, z2, w0, w2 = f

    # Re-express the control point relative to the segment start -> end.
    x1 = x0 + (x2 - x0) * x1
    y1 = y0 + (y2 - y0) * y1

    # normal() comes from the Renderer.stroke_gen star import; presumably it
    # maps a unit-range coordinate to a pixel index for the given canvas size
    # -- TODO confirm against that module.
    canvas_size = width * 2
    x0 = normal(x0, canvas_size)
    x1 = normal(x1, canvas_size)
    x2 = normal(x2, canvas_size)
    y0 = normal(y0, canvas_size)
    y1 = normal(y1, canvas_size)
    y2 = normal(y2, canvas_size)

    # Half-size of the square at each end, in canvas pixels (always >= 1 for
    # non-negative inputs).
    z0 = int(1 + z0 * width // 2)
    z2 = int(1 + z2 * width // 2)

    canvas = np.zeros([canvas_size, canvas_size]).astype('float32')
    step_size = 1. / 100
    for i in range(100):
        t = i * step_size
        # Quadratic Bezier position, truncated to integer pixel coordinates.
        x = int((1 - t) * (1 - t) * x0 + 2 * t * (1 - t) * x1 + t * t * x2)
        y = int((1 - t) * (1 - t) * y0 + 2 * t * (1 - t) * y1 + t * t * y2)
        # Linear interpolation of half-size and fill value along the stroke.
        z = int((1 - t) * z0 + t * z2)
        w = (1 - t) * w0 + t * w2
        # Corner points are passed to OpenCV as (y, x); thickness -1 fills the
        # rectangle.
        cv2.rectangle(canvas, (y - z, x - z), (y + z, x + z), w, -1)
    return 1 - cv2.resize(canvas, dsize=(width, width))

In [39]:
def save_model():
    """Write the renderer's weights to ./renderer.pkl as CPU tensors.

    Reads the module-level ``net`` and ``use_cuda``. When CUDA is in use the
    model is moved to the CPU for the duration of the save and then moved back.
    The move back happens in a ``finally`` so that a failed save (full disk,
    bad path, interrupt) does not leave the model on the CPU while the
    training loop keeps feeding it CUDA tensors.

    Raises:
        Whatever ``net.state_dict()`` or ``torch.save`` raises; the model's
        device is restored before the exception propagates.
    """
    if use_cuda:
        net.cpu()
    try:
        torch.save(net.state_dict(), "./renderer.pkl")
    finally:
        if use_cuda:
            net.cuda()


def load_weights():
    """Restore ``net`` from ./renderer.pkl, ignoring entries the model lacks.

    Only keys that already exist in the current model's state dict are taken
    from the checkpoint; any key missing from the checkpoint keeps the model's
    current value.

    NOTE(review): torch.load unpickles the file, so only load checkpoints from
    a trusted source.
    """
    checkpoint = torch.load("./renderer.pkl")
    merged_state = net.state_dict()
    for key, value in checkpoint.items():
        if key in merged_state:
            merged_state[key] = value
    net.load_state_dict(merged_state)


# load_weights()
# 500000
# Train the renderer. Each step draws a fresh batch of random stroke parameter
# vectors, rasterizes them with draw_rec as the target, and regresses the
# network output onto those rasters with MSE.
if use_cuda:
    # Moving the model is loop-invariant, so do it once instead of every step.
    net = net.cuda()

while step < 50000:
    # Pick the learning rate for this step BEFORE the update is applied, so
    # step 0 already trains at the scheduled rate and each boundary takes
    # effect on the step it names. (A longer 500000-step run was previously
    # configured with boundaries at 200000 and 400000.)
    if step < 20000:
        lr = 1e-4
    elif step < 40000:
        lr = 1e-5
    else:
        lr = 1e-6
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    # eval() is switched on every 10 steps below, so re-enter train mode here.
    net.train()
    train_batch = []
    ground_truth = []
    for _ in range(batch_size):
        f = np.random.uniform(0, 1, 10)
        train_batch.append(f)
        ground_truth.append(draw_rec(f))

    # Stack into a single ndarray first: building a tensor directly from a
    # Python list of ndarrays goes through a slow element-wise path.
    train_batch = torch.from_numpy(np.stack(train_batch)).float()
    ground_truth = torch.from_numpy(np.stack(ground_truth)).float()
    if use_cuda:
        train_batch = train_batch.cuda()
        ground_truth = ground_truth.cuda()

    optimizer.zero_grad()
    gen = net(train_batch)
    loss = criterion(gen, ground_truth)
    loss.backward()
    optimizer.step()

    train_loss = loss.item()
    print(step, train_loss)
    writer.add_scalar("train/loss", train_loss, step)

    if step % 10 == 0:
        # "val/loss" is measured on the batch that was just trained on, with
        # the network in eval mode; no gradients are needed for this pass.
        net.eval()
        with torch.no_grad():
            gen = net(train_batch)
            loss = criterion(gen, ground_truth)
        writer.add_scalar("val/loss", loss.item(), step)
        # Log at most 32 samples, and never more than the batch contains.
        for i in range(min(32, batch_size)):
            G = gen[i].cpu().numpy()
            GT = ground_truth[i].cpu().numpy()
            writer.add_image("train/gen{}.png".format(i), G, step)
            writer.add_image("train/ground_truth{}.png".format(i), GT, step)
    if step % 100 == 0:
        save_model()
    step += 1

0 0.19549666345119476
1 0.1973854899406433
2 0.19790826737880707
3 0.1990804523229599
4 0.19484663009643555
5 0.1915113478899002
6 0.19523634016513824
7 0.18891647458076477
8 0.19570937752723694
9 0.18840046226978302
10 0.20202884078025818
11 0.1901399940252304
12 0.19921591877937317
13 0.19563347101211548
14 0.19568216800689697
15 0.19558237493038177
16 0.19058682024478912
17 0.19623902440071106
18 0.19287648797035217
19 0.19244727492332458
20 0.19170454144477844
21 0.19337821006774902
22 0.19314181804656982
23 0.1942937672138214
24 0.19631287455558777
25 0.1920280009508133
26 0.19254088401794434
27 0.18856479227542877
28 0.19288265705108643
29 0.19884422421455383
30 0.19040675461292267
31 0.19036933779716492
32 0.19944025576114655
33 0.18976278603076935
34 0.19355599582195282
35 0.19024774432182312
36 0.19211550056934357
37 0.19310970604419708
38 0.1958107203245163
39 0.19083847105503082
40 0.1922377645969391
41 0.19029000401496887
42 0.18809325993061066
43 0.18611015379428864
44 0.1

345 0.04050657898187637
346 0.05645107477903366
347 0.04641180485486984
348 0.04318072274327278
349 0.05616465210914612
350 0.0459757000207901
351 0.05249868333339691
352 0.04690707102417946
353 0.045628707855939865
354 0.048415813595056534
355 0.04535304754972458
356 0.04519273340702057
357 0.0525292307138443
358 0.04341697692871094
359 0.047389715909957886
360 0.05315535143017769
361 0.05034352466464043
362 0.049909450113773346
363 0.04737238213419914
364 0.048947524279356
365 0.0446118526160717
366 0.0537588857114315
367 0.04061374440789223
368 0.05053449422121048
369 0.046651050448417664
370 0.05162510275840759
371 0.050863854587078094
372 0.04561108723282814
373 0.05478013679385185
374 0.05173374339938164
375 0.04505295306444168
376 0.04320273920893669
377 0.03850474953651428
378 0.05081522464752197
379 0.04620631784200668
380 0.046977896243333817
381 0.05296042189002037
382 0.03896210342645645
383 0.04905391111969948
384 0.04762332886457443
385 0.042212843894958496
386 0.03985425

682 0.02957393229007721
683 0.03061557002365589
684 0.02859896421432495
685 0.021533066406846046
686 0.031092621386051178
687 0.02728699892759323
688 0.028288288041949272
689 0.026681197807192802
690 0.029789432883262634
691 0.0244524534791708
692 0.026989281177520752
693 0.026073653250932693
694 0.023549504578113556
695 0.02546806074678898
696 0.025223301723599434
697 0.028753001242876053
698 0.02246514894068241
699 0.03324771299958229
700 0.02677524834871292
701 0.03230055049061775
702 0.026331409811973572
703 0.023044219240546227
704 0.029756218194961548
705 0.03284820169210434
706 0.027897313237190247
707 0.028239311650395393
708 0.028186488896608353
709 0.03129921853542328
710 0.02541663870215416
711 0.0247968602925539
712 0.021815210580825806
713 0.025837143883109093
714 0.025154639035463333
715 0.023728055879473686
716 0.028790924698114395
717 0.03154674172401428
718 0.02999228984117508
719 0.024854594841599464
720 0.02587105892598629
721 0.03288858383893967
722 0.02297795750200

KeyboardInterrupt: 