-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
96 lines (87 loc) · 3.32 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
from time import time
import cv2
import torch
from numpy import uint8
import data
import config as cfg
from networks import Generator, Discriminator
class GAN:
    """Conditional GAN trainer pairing a ``Generator`` with a ``Discriminator``.

    The generator maps an input image to a fake target; the discriminator
    scores ``(input, target)`` pairs. Both networks, their loss modules and
    their Adam optimizers live on CUDA device 0. Hyper-parameters come from
    the project ``config`` module (imported as ``cfg``).
    """

    def __init__(self):
        """Create the dataset, both networks, the loss criteria and optimizers."""
        self.dataset = data.Dataset()
        self.generator = Generator(cfg.g_channels).cuda(0)
        self.discriminator = Discriminator(cfg.d_channels).cuda(0)
        # L1 = reconstruction term for G; BCE = adversarial term for G and D.
        # NOTE(review): BCELoss expects probabilities, so the discriminator
        # presumably ends in a sigmoid -- confirm in networks.Discriminator.
        self.loss = {
            "L1": torch.nn.L1Loss().cuda(0),
            "BCE": torch.nn.BCELoss().cuda(0),
        }
        # Positional Adam args: (params, lr, betas).
        self.optimizers = {
            "G": torch.optim.Adam(
                self.generator.parameters(), cfg.lr, (0.5, 0.999)
            ),
            "D": torch.optim.Adam(
                self.discriminator.parameters(), cfg.lr, (0.5, 0.999)
            ),
        }

    def test(self, image, epoch):
        """Render the generator's output for ``image`` to ``log/<epoch>.png``.

        Args:
            image: batched input tensor on the generator's device; only the
                first sample of the output batch is written.
            epoch: value used as the output file name.

        Leaves the generator in eval mode; ``train()`` switches it back.
        """
        self.generator.eval()
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            fake = self.generator(image)[0].cpu()
        # Map [-1, 1] -> [0, 255]. Clamp first so out-of-range values saturate
        # instead of wrapping around in the uint8 cast. transpose((1, 2, 0))
        # turns the per-sample tensor's first axis into the last (HWC for cv2).
        pixels = uint8(
            127.5 * (fake.clamp(-1, 1).numpy().transpose((1, 2, 0)) + 1)
        )
        cv2.imwrite(
            os.path.join("log", f"{epoch}.png"),
            # Channel-order swap before writing (cv2.imwrite expects BGR).
            cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB),
        )

    def train(self):
        """Train both networks for ``cfg.epoch`` epochs, then save their weights.

        After each epoch a sample generated from ``data/test.png`` is written
        to ``log/<epoch>.png``. At the end, the state dicts are saved to
        ``networks/generator.pt`` and ``networks/discriminator.pt``.
        """
        os.makedirs("log", exist_ok=True)
        test_image = torch.unsqueeze(self.dataset.normalize(
            self.dataset.load(os.path.join("data", "test.png"))
        ), 0).cuda(0)
        start_time = time()
        self.discriminator.train()
        for epoch in range(cfg.epoch):
            # test() puts the generator in eval mode at the end of each epoch.
            self.generator.train()
            loader = data.DataLoader(
                self.dataset, cfg.batch_size, True, drop_last=True
            )
            for batch, (image, target) in enumerate(loader):
                # The networks live on cuda:0; .cuda(0) is a no-op when the
                # dataset already yields tensors on that device.
                image, target = image.cuda(0), target.cuda(0)
                fake = self.generator(image)

                # Discriminator step: real pairs -> 1, fake pairs -> 0.
                # fake is detached so this step does not backprop into G.
                self.optimizers["D"].zero_grad()
                real_score = self.discriminator(image, target)
                fake_score = self.discriminator(image, fake.detach())
                d_loss = (
                    self.loss["BCE"](real_score, torch.ones_like(real_score))
                    + self.loss["BCE"](fake_score, torch.zeros_like(fake_score))
                ) / 2
                d_loss.backward()
                self.optimizers["D"].step()

                # Generator step: make D label fakes as real, plus an L1 term
                # pulling the output toward the target.
                self.optimizers["G"].zero_grad()
                fake_score = self.discriminator(image, fake)
                g_loss = (
                    cfg.l1_lambda * self.loss["L1"](fake, target)
                    + self.loss["BCE"](fake_score, torch.ones_like(fake_score))
                )
                g_loss.backward()
                self.optimizers["G"].step()

                self.progress(epoch, batch + 1, start_time)
            self.test(test_image, epoch + 1)
        # progress() rewrites one line with "\r"; finish it with a newline.
        print()
        torch.save(
            self.generator.state_dict(),
            os.path.join("networks", "generator.pt")
        )
        torch.save(
            self.discriminator.state_dict(),
            os.path.join("networks", "discriminator.pt")
        )

    def progress(self, epoch, batch, start_time):
        """Print a single-line progress bar with percentage and ETA.

        Args:
            epoch: zero-based epoch index.
            batch: one-based batch index within the epoch.
            start_time: ``time()`` value captured when training started.
        """
        # Batches per epoch; floor division matches drop_last=True in train().
        step = len(self.dataset) // cfg.batch_size
        total = cfg.epoch * step
        complete = epoch * step + batch
        # Linear extrapolation: average time per batch times batches left.
        eta = round((time() - start_time) * (total - complete) / complete)
        hours, rest = divmod(eta, 3600)
        minutes, seconds = divmod(rest, 60)
        print("\rTraining: [{}>{}] {:.2f}% eta: {:02}:{:02}:{:02}".format(
            '-' * epoch, '.' * (cfg.epoch - epoch - 1), 100 * complete / total,
            hours, minutes, seconds
        ), end="")
if __name__ == "__main__":
    # Script entry point: build the trainer and run the full training loop.
    GAN().train()