Commit

merge dis

hytseng0509 committed Jul 22, 2018
1 parent e80906c commit 59439cb
Showing 4 changed files with 75 additions and 46 deletions.
20 changes: 10 additions & 10 deletions src/model.py
@@ -2,8 +2,6 @@
import torch
from torch.autograd import Variable
import torch.nn as nn
-#import itertools
-import numpy as np

class DRIT(nn.Module):
def __init__(self, opts):
@@ -15,14 +13,16 @@ def __init__(self, opts):
self.concat = opts.concat

# discriminators
-'''self.disA = networks.Dis(opts.input_dim_a)
-self.disB = networks.Dis(opts.input_dim_b)
-self.disA2 = networks.Dis(opts.input_dim_a)
-self.disB2 = networks.Dis(opts.input_dim_b)'''
-self.disA = networks.MultiScaleDis(opts.input_dim_a)
-self.disB = networks.MultiScaleDis(opts.input_dim_b)
-self.disA2 = networks.MultiScaleDis(opts.input_dim_a)
-self.disB2 = networks.MultiScaleDis(opts.input_dim_b)
+if opts.dis_scale > 1:
+  self.disA = networks.MultiScaleDis(opts.input_dim_a, opts.dis_scale, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
+  self.disB = networks.MultiScaleDis(opts.input_dim_b, opts.dis_scale, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
+  self.disA2 = networks.MultiScaleDis(opts.input_dim_a, opts.dis_scale, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
+  self.disB2 = networks.MultiScaleDis(opts.input_dim_b, opts.dis_scale, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
+else:
+  self.disA = networks.Dis(opts.input_dim_a, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
+  self.disB = networks.Dis(opts.input_dim_b, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
+  self.disA2 = networks.Dis(opts.input_dim_a, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
+  self.disB2 = networks.Dis(opts.input_dim_b, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
self.disContent = networks.Dis_content()

# encoders
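The hunk above replaces the four unconditional MultiScaleDis instances with a switch on the new dis_scale option, and threads dis_norm and dis_spectral_norm through to both discriminator types. A minimal sketch of the selection logic, using a hypothetical build_discriminator helper (not in the repo) and assuming an opts namespace with the fields added in src/options.py below:

import networks  # DRIT's src/networks.py

def build_discriminator(input_dim, opts):
  # dis_scale > 1 selects the multi-scale discriminator; otherwise the
  # single-scale Dis is used. Both constructors now accept norm/sn.
  if opts.dis_scale > 1:
    return networks.MultiScaleDis(input_dim, opts.dis_scale,
                                  norm=opts.dis_norm, sn=opts.dis_spectral_norm)
  return networks.Dis(input_dim, norm=opts.dis_norm, sn=opts.dis_spectral_norm)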
91 changes: 57 additions & 34 deletions src/networks.py
@@ -6,15 +6,15 @@
import torch.nn.functional as F

####################################################################
#------------------------- Discriminators --------------------------
####################################################################
class Dis_content(nn.Module):
def __init__(self):
super(Dis_content, self).__init__()
model = []
-model += [LeakyReLUINSConv2d(256, 256, kernel_size=7, stride=2, padding=1)]
-model += [LeakyReLUINSConv2d(256, 256, kernel_size=7, stride=2, padding=1)]
-model += [LeakyReLUINSConv2d(256, 256, kernel_size=7, stride=2, padding=1)]
+model += [LeakyReLUConv2d(256, 256, kernel_size=7, stride=2, padding=1, norm='Instance')]
+model += [LeakyReLUConv2d(256, 256, kernel_size=7, stride=2, padding=1, norm='Instance')]
+model += [LeakyReLUConv2d(256, 256, kernel_size=7, stride=2, padding=1, norm='Instance')]
model += [LeakyReLUConv2d(256, 256, kernel_size=4, stride=1, padding=0)]
model += [nn.Conv2d(256, 1, kernel_size=1, stride=1, padding=0)]
self.model = nn.Sequential(*model)
@@ -27,22 +27,25 @@ def forward(self, x):
return outs

class MultiScaleDis(nn.Module):
-def __init__(self, input_dim, n_scale=3, n_layer=4):
+def __init__(self, input_dim, n_scale=3, n_layer=4, norm='None', sn=False):
super(MultiScaleDis, self).__init__()
ch = 64
self.downsample = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
self.Diss = nn.ModuleList()
for _ in range(n_scale):
-self.Diss.append(self._make_net(ch, input_dim, n_layer))
+self.Diss.append(self._make_net(ch, input_dim, n_layer, norm, sn))

-def _make_net(self, ch, input_dim, n_layer):
+def _make_net(self, ch, input_dim, n_layer, norm, sn):
model = []
-model += [LeakyReLUINSConv2d(input_dim, ch, 4, 2, 1)]
+model += [LeakyReLUConv2d(input_dim, ch, 4, 2, 1, norm, sn)]
tch = ch
for _ in range(1, n_layer):
-model += [LeakyReLUINSConv2d(tch, tch * 2, 4, 2, 1)]
+model += [LeakyReLUConv2d(tch, tch * 2, 4, 2, 1, norm, sn)]
tch *= 2
-model += [nn.Conv2d(tch, 1, 1, 1, 0)]
+if sn:
+  model += [spectral_norm(nn.Conv2d(tch, 1, 1, 1, 0))]
+else:
+  model += [nn.Conv2d(tch, 1, 1, 1, 0)]
return nn.Sequential(*model)

def forward(self, x):
@@ -53,39 +56,39 @@ def forward(self, x):
return outs

class Dis(nn.Module):
-def __init__(self, input_dim):
+def __init__(self, input_dim, norm='None', sn=False):
super(Dis, self).__init__()
ch = 64
n_layer = 6
-self.model_A = self._make_net(ch, input_dim, n_layer)
+self.model = self._make_net(ch, input_dim, n_layer, norm, sn)

-def _make_net(self, ch, input_dim, n_layer):
+def _make_net(self, ch, input_dim, n_layer, norm, sn):
model = []
-model += [LeakyReLUINSConv2d(input_dim, ch, kernel_size=3, stride=2, padding=1)] #16
-#model += [Spectral_LeakyReLUConv2d(input_dim, ch, kernel_size=3, stride=2, padding=1)] #16
+model += [LeakyReLUConv2d(input_dim, ch, kernel_size=3, stride=2, padding=1, norm=norm, sn=sn)] #16
tch = ch
for i in range(1, n_layer-1):
-model += [LeakyReLUINSConv2d(tch, tch * 2, kernel_size=3, stride=2, padding=1)] # 8
-#model += [Spectral_LeakyReLUConv2d(tch, tch * 2, kernel_size=3, stride=2, padding=1)] # 8
+model += [LeakyReLUConv2d(tch, tch * 2, kernel_size=3, stride=2, padding=1, norm=norm, sn=sn)] # 8
tch *= 2
-model += [LeakyReLUConv2d(tch, tch * 2, kernel_size=3, stride=2, padding=1)] # 1
+model += [LeakyReLUConv2d(tch, tch * 2, kernel_size=3, stride=2, padding=1, norm='None', sn=sn)] # 2
tch *= 2
-model += [nn.Conv2d(tch, 1, kernel_size=1, stride=1, padding=0)] # 1
-#model += [spectral_norm(nn.Conv2d(tch, 1, kernel_size=1, stride=1, padding=0))] # 1
+if sn:
+  model += [spectral_norm(nn.Conv2d(tch, 1, kernel_size=1, stride=1, padding=0))] # 1
+else:
+  model += [nn.Conv2d(tch, 1, kernel_size=1, stride=1, padding=0)] # 1
return nn.Sequential(*model)

def cuda(self,gpu):
-self.model_A.cuda(gpu)
+self.model.cuda(gpu)

def forward(self, x_A):
-out_A = self.model_A(x_A)
+out_A = self.model(x_A)
out_A = out_A.view(-1)
outs_A = []
outs_A.append(out_A)
return outs_A

####################################################################
#---------------------------- Encoders -----------------------------
####################################################################
class E_content(nn.Module):
def __init__(self, input_dim_a, input_dim_b):
@@ -129,7 +132,7 @@ def forward_a(self, xa):
outputA = self.conv_share(outputA)
return outputA

-def forward(self, xb):
+def forward_b(self, xb):
outputB = self.convB(xb)
outputB = self.conv_share(outputB)
return outputB
@@ -178,7 +181,7 @@ def forward_a(self, xa):
output_A = xa.view(xa.size(0), -1)
return output_A

-def forward(self, xb):
+def forward_b(self, xb):
xb = self.model_b(xb)
output_B = xb.view(xb.size(0), -1)
return output_B
@@ -229,14 +232,15 @@ def forward_a(self, xa):
outputVar_A = self.fcVar_A(conv_flat_A)
return output_A, outputVar_A

-def forward(self, xb):
+def forward_b(self, xb):
x_conv_B = self.conv_B(xb)
conv_flat_B = x_conv_B.view(xb.size(0), -1)
output_B = self.fc_B(conv_flat_B)
outputVar_B = self.fcVar_B(conv_flat_B)
return output_B, outputVar_B

####################################################################
#--------------------------- Generators ----------------------------
####################################################################
class G(nn.Module):
def __init__(self, output_dim_a, output_dim_b, nz):
@@ -250,7 +254,7 @@ def __init__(self, output_dim_a, output_dim_b, nz):
self.decA2 = MisINSResBlock(tch, tch_add)
self.decA3 = MisINSResBlock(tch, tch_add)
self.decA4 = MisINSResBlock(tch, tch_add)

decA5 = []
decA5 += [ReLUINSConvTranspose2d(tch, tch//2, kernel_size=3, stride=2, padding=1, output_padding=1)]
tch = tch//2
@@ -385,7 +389,7 @@ def forward_b(self, x, z):
return out4

####################################################################
#------------------------- Basic Functions -------------------------
####################################################################
def get_scheduler(optimizer, opts, cur_ep=-1):
if opts.lr_policy == 'lambda':
@@ -441,7 +445,7 @@ def gaussian_weights_init(m):
m.weight.data.normal_(0.0, 0.02)

####################################################################
#-------------------------- Basic Blocks --------------------------
####################################################################
class LayerNorm(nn.Module):
def __init__(self, n_out, eps=1e-5, affine=True):
@@ -481,7 +485,24 @@ def forward(self, x):
out = self.conv(x) + self.shortcut(x)
return out

-class Spectral_LeakyReLUConv2d(nn.Module):
+class LeakyReLUConv2d(nn.Module):
+  def __init__(self, n_in, n_out, kernel_size, stride, padding=0, norm='None', sn=False):
+    super(LeakyReLUConv2d, self).__init__()
+    model = []
+    if sn:
+      model += [spectral_norm(nn.Conv2d(n_in, n_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=True))]
+    else:
+      model += [nn.Conv2d(n_in, n_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=True)]
+    if norm == 'Instance':
+      model += [nn.InstanceNorm2d(n_out, affine=False)]
+    # TODO: support norm == 'Group'
+    model += [nn.LeakyReLU(inplace=True)]
+    self.model = nn.Sequential(*model)
+    self.model.apply(gaussian_weights_init)
+  def forward(self, x):
+    return self.model(x)
+
+'''class Spectral_LeakyReLUConv2d(nn.Module):
def __init__(self, n_in, n_out, kernel_size, stride, padding=0):
super(Spectral_LeakyReLUConv2d, self).__init__()
model = []
@@ -513,7 +534,7 @@ def __init__(self, n_in, n_out, kernel_size, stride, padding=0):
self.model = nn.Sequential(*model)
self.model.apply(gaussian_weights_init)
def forward(self, x):
-return self.model(x)
+return self.model(x)'''

class ReLUINSConv2d(nn.Module):
def __init__(self, n_in, n_out, kernel_size, stride, padding=0):
@@ -604,6 +625,7 @@ def __init__(self, n_in, n_out, kernel_size, stride, padding, output_padding):
super(ReLUINSConvTranspose2d, self).__init__()
model = []
model += [nn.ConvTranspose2d(n_in, n_out, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=True)]
+#model += [nn.LayerNorm()]
model += [LayerNorm(n_out)]
model += [nn.ReLU(inplace=True)]
self.model = nn.Sequential(*model)
@@ -613,7 +635,7 @@ def forward(self, x):


####################################################################
#--------------------- Spectral Normalization ---------------------
# This part of code is copied from pytorch master branch (0.5.0)
####################################################################
class SpectralNorm(object):
@@ -687,4 +709,5 @@ def remove_spectral_norm(module, name='weight'):
hook.remove(module)
del module._forward_pre_hooks[k]
return module
raise ValueError("spectral_norm of '{}' not found in {}".format(name, module))

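The core of this commit is merging LeakyReLUINSConv2d and Spectral_LeakyReLUConv2d into the single LeakyReLUConv2d above, gated by the norm and sn keywords; collapsing the variants is what lets Dis, MultiScaleDis, and Dis_content share one code path. Below is a self-contained sketch of the merged block with a shape check. It uses torch.nn.utils.spectral_norm, the upstream version of the SpectralNorm copy vendored above, and is an illustration rather than the repo's exact code:

import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm  # upstream equivalent of the vendored copy

class LeakyReLUConv2d(nn.Module):
  def __init__(self, n_in, n_out, kernel_size, stride, padding=0, norm='None', sn=False):
    super(LeakyReLUConv2d, self).__init__()
    conv = nn.Conv2d(n_in, n_out, kernel_size, stride, padding, bias=True)
    model = [spectral_norm(conv) if sn else conv]  # sn reparameterizes only the conv weight
    if norm == 'Instance':  # norm='None' skips normalization entirely
      model += [nn.InstanceNorm2d(n_out, affine=False)]
    model += [nn.LeakyReLU(inplace=True)]
    self.model = nn.Sequential(*model)

  def forward(self, x):
    return self.model(x)

block = LeakyReLUConv2d(3, 64, kernel_size=3, stride=2, padding=1, norm='Instance', sn=True)
print(block(torch.randn(1, 3, 64, 64)).shape)  # torch.Size([1, 64, 32, 32])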
8 changes: 7 additions & 1 deletion src/options.py
@@ -22,10 +22,13 @@ def __init__(self):
self.parser.add_argument('--display_freq', type=int, default=1, help='freq (iteration) of display')
self.parser.add_argument('--img_save_freq', type=int, default=50, help='freq (epoch) of saving images')
self.parser.add_argument('--model_save_freq', type=int, default=10, help='freq (epoch) of saving models')
-self.parser.add_argument('--no_displayimg', action='store_true', help='specified if no dispaly')
+self.parser.add_argument('--no_display_img', action='store_true', help='specified if no display')

# training related
self.parser.add_argument('--concat', type=int, default=1, help='concatenate attribute features for translation')
+self.parser.add_argument('--dis_scale', type=int, default=3, help='scale of discriminator')
+self.parser.add_argument('--dis_norm', type=str, default='None', help='normalization layer in discriminator [None, Instance]')
+self.parser.add_argument('--dis_spectral_norm', type=int, default=0, help='use spectral normalization in discriminator')
self.parser.add_argument('--lr_policy', type=str, default='lambda', help='type of learn rate decay')
self.parser.add_argument('--n_ep', type=int, default=1200, help='number of epochs') # 400 * d_iter
self.parser.add_argument('--n_ep_decay', type=int, default=600, help='epoch start decay learning rate, set -1 if no decay') # 200 * d_iter
@@ -63,6 +66,9 @@ def __init__(self):

# model related
self.parser.add_argument('--concat', type=int, default=1, help='concatenate attribute features for translation')
+self.parser.add_argument('--dis_scale', type=int, default=3, help='scale of discriminator')
+self.parser.add_argument('--dis_norm', type=str, default='None', help='normalization layer in discriminator [None, Instance]')
+self.parser.add_argument('--dis_spectral_norm', type=int, default=0, help='use spectral normalization in discriminator')
self.parser.add_argument('--resume', type=str, default=None, help='specified the dir of saved models for resume the training')
self.parser.add_argument('--gpu', type=int, default=0, help='gpu')

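Both option classes gain the same three discriminator flags. A quick sketch of how they parse, assuming a parser with only these arguments (the full parsers in options.py define many more):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dis_scale', type=int, default=3, help='scale of discriminator')
parser.add_argument('--dis_norm', type=str, default='None', help='normalization layer in discriminator [None, Instance]')
parser.add_argument('--dis_spectral_norm', type=int, default=0, help='use spectral normalization in discriminator')

opts = parser.parse_args(['--dis_scale', '1', '--dis_spectral_norm', '1'])
# dis_scale == 1 takes the single-scale Dis branch in model.py, and
# dis_spectral_norm is an int flag, so pass 0 or 1 on the command line.
print(opts.dis_scale > 1, bool(opts.dis_spectral_norm))  # False True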
2 changes: 1 addition & 1 deletion src/train.py
@@ -53,7 +53,7 @@ def main():
model.update_EG()

# save to display file
-if not opt.no_dispalyimg:
+if not opts.no_display_img:
saver.write_display(total_it, model)

print('total_it: %d (ep %d, it %d), lr %08f' % (total_it, ep, it, model.gen_opt.param_groups[0]['lr']))
