wgan.py

import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from torch.optim import Adam

from discriminator import Discriminator
from generator import Generator

def weights_init(m):
    # DCGAN-style initialisation: conv and batch-norm weights drawn from N(0, 0.02).
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

class WGAN(pl.LightningModule):
    def __init__(self, generator: Generator, discriminator: Discriminator):
        super().__init__()
        # Generator and critic are stepped by hand in training_step below,
        # so Lightning's automatic optimization must be disabled.
        self.automatic_optimization = False
        self.generator = generator.apply(weights_init)
        self.discriminator = discriminator.apply(weights_init)

    def forward(self, noise):
        return self.generator(noise)

    def get_gradient(self, real, fake, epsilon):
        mixed_images = real * epsilon + fake * (1 - epsilon)
        mixed_scores = self.discriminator(mixed_images)
        gradient = torch.autograd.grad(
            inputs=mixed_images,
            outputs=mixed_scores,
            grad_outputs=torch.ones_like(mixed_scores),
            create_graph=True,
            retain_graph=True,
        )[0]
        return gradient

    def gradient_penalty(self, gradient):
        # Penalise the squared distance of each sample's gradient norm from 1.
        gradient = gradient.view(len(gradient), -1)
        gradient_norm = gradient.norm(2, dim=1)
        return torch.mean(torch.pow(gradient_norm - 1, 2))

    def gen_loss(self, fake_pred):
        # The generator maximises the critic's score on fakes.
        return -torch.mean(fake_pred)

    def disc_loss(self, fake_pred, real_pred, gradient_penalty, c_lambda):
        # Wasserstein critic loss plus the weighted gradient penalty.
        return torch.mean(fake_pred) - torch.mean(real_pred) + c_lambda * gradient_penalty

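    # The two losses above implement the WGAN-GP objective (Gulrajani et al.,
    # 2017). As a hedged summary in pseudo-math, with x_hat the interpolate
    # built in get_gradient:
    #
    #   L_D = E[D(G(z))] - E[D(x)] + c_lambda * E[(||grad D(x_hat)||_2 - 1)^2]
    #   L_G = -E[D(G(z))]
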
    def train_generator(self, real, optimizer):
        optimizer.zero_grad()
        noise = self.generator.gen_noize(len(real), device=self.device)
        fake = self.generator(noise)
        fake_pred = self.discriminator(fake)
        fake_loss = self.gen_loss(fake_pred)
        self.manual_backward(fake_loss)
        optimizer.step()
        self.log_dict({"g_loss": fake_loss})
        return fake_loss

    def train_discriminator(self, real, optimizer, repeats=5, c_lambda=10):
        # The critic takes several update steps per generator step
        # (repeats=5 is the usual WGAN-GP setting).
        mean_disc_loss = 0
        for _ in range(repeats):
            optimizer.zero_grad()
            real_pred = self.discriminator(real)
            noise = self.generator.gen_noize(len(real), device=self.device)
            fake = self.generator(noise).detach()
            fake_pred = self.discriminator(fake)
            epsilon = torch.rand(len(real), 1, 1, 1, requires_grad=True, device=self.device)
            gradient = self.get_gradient(real, fake, epsilon)
            gradient_penalty = self.gradient_penalty(gradient)
            disc_loss = self.disc_loss(fake_pred, real_pred, gradient_penalty, c_lambda)
            mean_disc_loss += disc_loss.item() / repeats
            # Each iteration builds a fresh graph, so no retain_graph is needed here.
            self.manual_backward(disc_loss)
            optimizer.step()
        self.log_dict({"d_loss": mean_disc_loss})
        return mean_disc_loss

    def training_step(self, batch, batch_idx):
        real, _ = batch
        opt_gen, opt_disc = self.optimizers()
        # Standard WGAN-GP schedule: several critic steps, then one generator step.
        self.train_discriminator(real, opt_disc)
        self.train_generator(real, opt_gen)

    def configure_optimizers(self):
        optimizer_gen = Adam(params=self.generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
        optimizer_disc = Adam(params=self.discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
        return [optimizer_gen, optimizer_disc], []

    def on_train_epoch_end(self):
        # Log a grid of generated samples to the logger once per epoch.
        with torch.no_grad():
            noise = self.generator.gen_noize(device=self.device)
            fake = self.generator(noise)
        img_grid = torchvision.utils.make_grid(fake)
        self.logger.experiment.add_image("generated_images", img_grid, self.current_epoch)
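
# A minimal usage sketch, not part of the training module above: it assumes
# the sibling generator.py / discriminator.py expose zero-argument
# constructors and that batches arrive as (image, label) pairs, e.g. from
# torchvision's MNIST. The dataset and hyperparameters here are illustrative
# assumptions, not taken from this repository.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    dataset = datasets.MNIST("data", train=True, download=True, transform=transform)
    loader = DataLoader(dataset, batch_size=128, shuffle=True)

    model = WGAN(Generator(), Discriminator())  # assumed constructor signatures
    trainer = pl.Trainer(max_epochs=25)
    trainer.fit(model, loader)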