# cyclegan.py (forked from fastai/fastai)
from ..torch_core import *
from ..layers import *
from ..callback import *
from ..basic_train import Learner, LearnerCallback
__all__ = ['CycleGAN', 'CycleGanLoss', 'AdaptiveLoss', 'CycleGANTrainer']
def convT_norm_relu(ch_in:int, ch_out:int, norm_layer:nn.Module, ks:int=3, stride:int=2, bias:bool=True):
return [nn.ConvTranspose2d(ch_in, ch_out, kernel_size=ks, stride=stride, padding=1, output_padding=1, bias=bias),
norm_layer(ch_out), nn.ReLU(True)]
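
# Usage sketch (illustrative, not part of the original module): with the default
# stride of 2 and output_padding of 1, this block doubles the spatial resolution, e.g.
#   nn.Sequential(*convT_norm_relu(128, 64, nn.InstanceNorm2d))
# maps a (1, 128, 64, 64) tensor to (1, 64, 128, 128).
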
def pad_conv_norm_relu(ch_in:int, ch_out:int, pad_mode:str, norm_layer:nn.Module, ks:int=3, bias:bool=True,
pad=1, stride:int=1, activ:bool=True, init:Callable=nn.init.kaiming_normal_)->List[nn.Module]:
layers = []
if pad_mode == 'reflection': layers.append(nn.ReflectionPad2d(pad))
elif pad_mode == 'border': layers.append(nn.ReplicationPad2d(pad))
p = pad if pad_mode == 'zeros' else 0
conv = nn.Conv2d(ch_in, ch_out, kernel_size=ks, padding=p, stride=stride, bias=bias)
if init:
init(conv.weight)
if hasattr(conv, 'bias') and hasattr(conv.bias, 'data'): conv.bias.data.fill_(0.)
layers += [conv, norm_layer(ch_out)]
if activ: layers.append(nn.ReLU(inplace=True))
return layers
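
# Usage sketch (illustrative, not part of the original module): the returned list is
# meant to be unpacked into an nn.Sequential, e.g. the generator's stem
#   nn.Sequential(*pad_conv_norm_relu(3, 64, 'reflection', nn.InstanceNorm2d, ks=7, pad=3))
# pads by 3 with reflection, then applies a 7x7 conv, InstanceNorm2d and ReLU,
# keeping the spatial size unchanged.
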
class ResnetBlock(nn.Module):
def __init__(self, dim:int, pad_mode:str='reflection', norm_layer:nn.Module=None, dropout:float=0., bias:bool=True):
super().__init__()
assert pad_mode in ['zeros', 'reflection', 'border'], f'padding {pad_mode} not implemented.'
norm_layer = ifnone(norm_layer, nn.InstanceNorm2d)
layers = pad_conv_norm_relu(dim, dim, pad_mode, norm_layer, bias=bias)
if dropout != 0: layers.append(nn.Dropout(dropout))
layers += pad_conv_norm_relu(dim, dim, pad_mode, norm_layer, bias=bias, activ=False)
self.conv_block = nn.Sequential(*layers)
def forward(self, x): return x + self.conv_block(x)
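
# Usage sketch (illustrative, not part of the original module): the block is
# shape-preserving, so it can be stacked any number of times, e.g.
#   block = ResnetBlock(256)
#   block(torch.randn(1, 256, 64, 64)).shape  # -> torch.Size([1, 256, 64, 64])
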
def resnet_generator(ch_in:int, ch_out:int, n_ftrs:int=64, norm_layer:nn.Module=None,
dropout:float=0., n_blocks:int=6, pad_mode:str='reflection')->nn.Module:
norm_layer = ifnone(norm_layer, nn.InstanceNorm2d)
bias = (norm_layer == nn.InstanceNorm2d)
layers = pad_conv_norm_relu(ch_in, n_ftrs, 'reflection', norm_layer, pad=3, ks=7, bias=bias)
for i in range(2):
layers += pad_conv_norm_relu(n_ftrs, n_ftrs *2, 'zeros', norm_layer, stride=2, bias=bias)
n_ftrs *= 2
layers += [ResnetBlock(n_ftrs, pad_mode, norm_layer, dropout, bias) for _ in range(n_blocks)]
for i in range(2):
layers += convT_norm_relu(n_ftrs, n_ftrs//2, norm_layer, bias=bias)
n_ftrs //= 2
layers += [nn.ReflectionPad2d(3), nn.Conv2d(n_ftrs, ch_out, kernel_size=7, padding=0), nn.Tanh()]
return nn.Sequential(*layers)
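
# Usage sketch (illustrative, not part of the original module): the generator
# downsamples twice, applies `n_blocks` residual blocks, then upsamples back, so
# input and output sizes match, e.g.
#   gen = resnet_generator(3, 3, n_blocks=9)
#   gen(torch.randn(1, 3, 256, 256)).shape  # -> torch.Size([1, 3, 256, 256])
# The final Tanh keeps outputs in [-1, 1], matching images normalized to that range.
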
def conv_norm_lr(ch_in:int, ch_out:int, norm_layer:nn.Module=None, ks:int=3, bias:bool=True, pad:int=1, stride:int=1,
activ:bool=True, slope:float=0.2, init:Callable=nn.init.kaiming_normal_)->List[nn.Module]:
conv = nn.Conv2d(ch_in, ch_out, kernel_size=ks, padding=pad, stride=stride, bias=bias)
if init:
init(conv.weight)
if hasattr(conv, 'bias') and hasattr(conv.bias, 'data'): conv.bias.data.fill_(0.)
layers = [conv]
if norm_layer is not None: layers.append(norm_layer(ch_out))
if activ: layers.append(nn.LeakyReLU(slope, inplace=True))
return layers
def critic(ch_in:int, n_ftrs:int=64, n_layers:int=3, norm_layer:nn.Module=None, sigmoid:bool=False)->nn.Module:
norm_layer = ifnone(norm_layer, nn.InstanceNorm2d)
bias = (norm_layer == nn.InstanceNorm2d)
layers = conv_norm_lr(ch_in, n_ftrs, ks=4, stride=2, pad=1)
for i in range(n_layers-1):
new_ftrs = 2*n_ftrs if i <= 3 else n_ftrs
layers += conv_norm_lr(n_ftrs, new_ftrs, norm_layer, ks=4, stride=2, pad=1, bias=bias)
n_ftrs = new_ftrs
new_ftrs = 2*n_ftrs if n_layers <=3 else n_ftrs
layers += conv_norm_lr(n_ftrs, new_ftrs, norm_layer, ks=4, stride=1, pad=1, bias=bias)
layers.append(nn.Conv2d(new_ftrs, 1, kernel_size=4, stride=1, padding=1))
if sigmoid: layers.append(nn.Sigmoid())
return nn.Sequential(*layers)
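
# Usage sketch (illustrative, not part of the original module): this is a PatchGAN
# critic, so it outputs a map of per-patch scores rather than a single scalar, e.g.
#   D = critic(3)
#   D(torch.randn(1, 3, 256, 256)).shape  # -> torch.Size([1, 1, 30, 30])
# With an LSGAN objective the raw scores go straight into an MSE loss;
# `sigmoid=True` is only needed for a plain binary cross-entropy objective.
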
class CycleGAN(nn.Module):
def __init__(self, ch_in:int, ch_out:int, n_features:int=64, disc_layers:int=3, gen_blocks:int=6, lsgan:bool=True,
drop:float=0., norm_layer:nn.Module=None):
super().__init__()
self.D_A = critic(ch_in, n_features, disc_layers, norm_layer, sigmoid=not lsgan)
self.D_B = critic(ch_in, n_features, disc_layers, norm_layer, sigmoid=not lsgan)
self.G_A = resnet_generator(ch_in, ch_out, n_features, norm_layer, drop, gen_blocks)
self.G_B = resnet_generator(ch_in, ch_out, n_features, norm_layer, drop, gen_blocks)
        #G_A: takes a real image from B and generates a fake image in A
        #G_B: takes a real image from A and generates a fake image in B
        #D_A: trained to tell real images from A apart from the fakes produced by G_A
        #D_B: trained to tell real images from B apart from the fakes produced by G_B
def forward(self, real_A, real_B):
fake_A, fake_B = self.G_A(real_B), self.G_B(real_A)
if not self.training: return torch.cat([fake_A[:,None],fake_B[:,None]], 1)
idt_A, idt_B = self.G_A(real_A), self.G_B(real_B)
return [fake_A, fake_B, idt_A, idt_B]
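
# Usage sketch (illustrative, not part of the original module):
#   cgan = CycleGAN(3, 3, gen_blocks=9)
#   fake_A, fake_B, idt_A, idt_B = cgan(real_A, real_B)   # training mode
# In eval mode only the two fakes are returned, stacked along a new dimension:
#   cgan.eval(); cgan(real_A, real_B).shape  # -> (bs, 2, 3, H, W)
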
class AdaptiveLoss(nn.Module):
def __init__(self, crit):
super().__init__()
self.crit = crit
def forward(self, output, target:bool):
targ = output.new_ones(*output.size()) if target else output.new_zeros(*output.size())
return self.crit(output, targ)
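
# Usage sketch (illustrative, not part of the original module; `D` and `fake_A` are
# placeholder names for any critic and generated batch): the boolean target is
# expanded to a tensor of ones or zeros with the same shape as the critic output,
# so the same criterion works for any PatchGAN output size, e.g.
#   adapt = AdaptiveLoss(F.mse_loss)
#   adapt(D(fake_A), False)   # the critic should score generated images as fake
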
class CycleGanLoss(nn.Module):
def __init__(self, cgan:nn.Module, lambda_A:float=10., lambda_B:float=10, lambda_idt:float=0.5, lsgan:bool=True):
super().__init__()
self.cgan,self.l_A,self.l_B,self.l_idt = cgan,lambda_A,lambda_B,lambda_idt
#self.crit = F.mse_loss if lsgan else F.binary_cross_entropy
self.crit = AdaptiveLoss(F.mse_loss if lsgan else F.binary_cross_entropy)
def set_input(self, input):
self.real_A,self.real_B = input
def forward(self, output, target):
fake_A, fake_B, idt_A, idt_B = output
#Generators should return identity on the datasets they try to convert to
        idt_loss = self.l_idt * (self.l_A * F.l1_loss(idt_A, self.real_A) + self.l_B * F.l1_loss(idt_B, self.real_B))
#Generators are trained to trick the critics so the following should be ones
gen_loss = self.crit(self.cgan.D_A(fake_A), True) + self.crit(self.cgan.D_B(fake_B), True)
#Cycle loss
cycle_loss = self.l_A * F.l1_loss(self.cgan.G_A(fake_B), self.real_A)
cycle_loss += self.l_B * F.l1_loss(self.cgan.G_B(fake_A), self.real_B)
self.metrics = [idt_loss, gen_loss, cycle_loss]
return idt_loss + gen_loss + cycle_loss
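
# Usage note (illustrative, not part of the original module): `set_input` must be
# called with the current (real_A, real_B) pair before `forward` is evaluated;
# `CycleGANTrainer.on_batch_begin` below does exactly that, so the loss can reuse
# the real images for the identity and cycle terms.
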
class CycleGANTrainer(LearnerCallback):
"`LearnerCallback` that handles cycleGAN Training."
_order=-20
def _set_trainable(self, D_A=False, D_B=False):
gen = (not D_A) and (not D_B)
requires_grad(self.learn.model.G_A, gen)
requires_grad(self.learn.model.G_B, gen)
requires_grad(self.learn.model.D_A, D_A)
requires_grad(self.learn.model.D_B, D_B)
if not gen:
self.opt_D_A.lr, self.opt_D_A.mom = self.learn.opt.lr, self.learn.opt.mom
self.opt_D_A.wd, self.opt_D_A.beta = self.learn.opt.wd, self.learn.opt.beta
self.opt_D_B.lr, self.opt_D_B.mom = self.learn.opt.lr, self.learn.opt.mom
self.opt_D_B.wd, self.opt_D_B.beta = self.learn.opt.wd, self.learn.opt.beta
def on_train_begin(self, **kwargs):
"Create the various optimizers."
self.G_A,self.G_B = self.learn.model.G_A,self.learn.model.G_B
self.D_A,self.D_B = self.learn.model.D_A,self.learn.model.D_B
self.crit = self.learn.loss_func.crit
self.opt_G = self.learn.opt.new([nn.Sequential(*flatten_model(self.G_A), *flatten_model(self.G_B))])
self.opt_D_A = self.learn.opt.new([nn.Sequential(*flatten_model(self.D_A))])
self.opt_D_B = self.learn.opt.new([nn.Sequential(*flatten_model(self.D_B))])
self.learn.opt.opt = self.opt_G.opt
self._set_trainable()
self.names = ['idt_loss', 'gen_loss', 'cyc_loss', 'da_loss', 'db_loss']
self.learn.recorder.no_val=True
self.learn.recorder.add_metric_names(self.names)
self.smootheners = {n:SmoothenValue(0.98) for n in self.names}
def on_batch_begin(self, last_input, **kwargs):
"Register the `last_input` in the loss function."
self.learn.loss_func.set_input(last_input)
def on_batch_end(self, last_input, last_output, **kwargs):
"Steps through the generators then each of the critics."
self.G_A.zero_grad(); self.G_B.zero_grad()
fake_A, fake_B = last_output[0].detach(), last_output[1].detach()
real_A, real_B = last_input
self._set_trainable(D_A=True)
self.D_A.zero_grad()
loss_D_A = 0.5 * (self.crit(self.D_A(real_A), True) + self.crit(self.D_A(fake_A), False))
loss_D_A.backward()
self.opt_D_A.step()
self._set_trainable(D_B=True)
self.D_B.zero_grad()
loss_D_B = 0.5 * (self.crit(self.D_B(real_B), True) + self.crit(self.D_B(fake_B), False))
loss_D_B.backward()
self.opt_D_B.step()
self._set_trainable()
metrics = self.learn.loss_func.metrics + [loss_D_A, loss_D_B]
for n,m in zip(self.names,metrics): self.smootheners[n].add_value(m)
def on_epoch_end(self, **kwargs):
"Put the various losses in the recorder."
self.learn.recorder.add_metrics([s.smooth for k,s in self.smootheners.items()])
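
# End-to-end sketch (illustrative, not part of the original module). It assumes a
# `data` DataBunch whose batches yield ((real_A, real_B), target) pairs, which this
# file does not define:
#   cgan = CycleGAN(3, 3, gen_blocks=9)
#   learn = Learner(data, cgan, loss_func=CycleGanLoss(cgan), callback_fns=[CycleGANTrainer])
#   learn.fit(100, 2e-4)
# The callback swaps the learner's optimizer between the generators and the critics,
# so no extra training loop code is needed.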