Add -normalize_gradients parameter #84

Merged (5 commits, Dec 6, 2020). Diff shows changes from all commits.
README.md: 3 changes (2 additions, 1 deletion)
@@ -195,6 +195,7 @@ path or a full absolute path.
when using ADAM you will probably need to play with other parameters to get good results, especially
the style weight, content weight, and learning rate.
* `-learning_rate`: Learning rate to use with the ADAM optimizer. Default is 1e1.
* `-normalize_gradients`: If this flag is present, style and content gradients from each layer will be L1 normalized.

**Output options**:
* `-output_image`: Name of the output image. Default is `out.png`.
@@ -313,4 +314,4 @@ If you find this code useful for your research, please cite:
journal = {GitHub repository},
howpublished = {\url{https://github.com/ProGamerGov/neural-style-pt}},
}
```
```
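For reference, the new option documented above is a plain boolean switch, so a run that enables it might look like the line below; the image paths are placeholders, not files referenced by this PR:

```
python neural_style.py -content_image <content.jpg> -style_image <style.jpg> -normalize_gradients
```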
neural_style.py: 45 changes (34 additions, 11 deletions)
Expand Up @@ -21,6 +21,7 @@
parser.add_argument("-content_weight", type=float, default=5e0)
parser.add_argument("-style_weight", type=float, default=1e2)
parser.add_argument("-normalize_weights", action='store_true')
parser.add_argument("-normalize_gradients", action='store_true')
parser.add_argument("-tv_weight", type=float, default=1e-3)
parser.add_argument("-num_iterations", type=int, default=1000)
parser.add_argument("-init", choices=['random', 'image'], default='random')
@@ -121,13 +122,13 @@ def main():

if layerList['C'][c] in content_layers:
print("Setting up content layer " + str(i) + ": " + str(layerList['C'][c]))
loss_module = ContentLoss(params.content_weight)
loss_module = ContentLoss(params.content_weight, params.normalize_gradients)
net.add_module(str(len(net)), loss_module)
content_losses.append(loss_module)

if layerList['C'][c] in style_layers:
print("Setting up style layer " + str(i) + ": " + str(layerList['C'][c]))
loss_module = StyleLoss(params.style_weight)
loss_module = StyleLoss(params.style_weight, params.normalize_gradients)
net.add_module(str(len(net)), loss_module)
style_losses.append(loss_module)
c+=1
@@ -137,14 +138,14 @@

if layerList['R'][r] in content_layers:
print("Setting up content layer " + str(i) + ": " + str(layerList['R'][r]))
loss_module = ContentLoss(params.content_weight)
loss_module = ContentLoss(params.content_weight, params.normalize_gradients)
net.add_module(str(len(net)), loss_module)
content_losses.append(loss_module)
next_content_idx += 1

if layerList['R'][r] in style_layers:
print("Setting up style layer " + str(i) + ": " + str(layerList['R'][r]))
loss_module = StyleLoss(params.style_weight)
loss_module = StyleLoss(params.style_weight, params.normalize_gradients)
net.add_module(str(len(net)), loss_module)
style_losses.append(loss_module)
next_style_idx += 1
@@ -339,15 +340,15 @@ def preprocess(image_name, image_size):
Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
rgb2bgr = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])])])
Normalize = transforms.Compose([transforms.Normalize(mean=[103.939, 116.779, 123.68], std=[1,1,1])])
tensor = Normalize(rgb2bgr(Loader(image) * 256)).unsqueeze(0)
tensor = Normalize(rgb2bgr(Loader(image) * 255)).unsqueeze(0)
return tensor


# Undo the above preprocessing.
def deprocess(output_tensor):
Normalize = transforms.Compose([transforms.Normalize(mean=[-103.939, -116.779, -123.68], std=[1,1,1])])
bgr2rgb = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])])])
output_tensor = bgr2rgb(Normalize(output_tensor.squeeze(0).cpu())) / 256
output_tensor = bgr2rgb(Normalize(output_tensor.squeeze(0).cpu())) / 255
output_tensor.clamp_(0, 1)
Image2PIL = transforms.ToPILImage()
image = Image2PIL(output_tensor.cpu())
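A note on the change from 256 to 255 above: `transforms.ToTensor()` maps 8-bit pixel values into [0, 1] by dividing by 255, so multiplying and dividing by 255 is what makes `preprocess` and `deprocess` exact inverses of each other. A minimal round-trip sketch (not part of the PR; the input path is a placeholder):

```python
import torch
from PIL import Image
from torchvision import transforms

img = Image.open("input.jpg").convert("RGB")   # placeholder image

mean = [103.939, 116.779, 123.68]
swap = transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])])  # RGB <-> BGR

# Forward: [0,1] RGB -> 0-255 BGR with the VGG channel means subtracted
t = transforms.Normalize(mean=mean, std=[1, 1, 1])(swap(transforms.ToTensor()(img) * 255))

# Inverse: add the means back, swap channels again, rescale to [0,1]
u = swap(transforms.Normalize(mean=[-m for m in mean], std=[1, 1, 1])(t)) / 255

print((u - transforms.ToTensor()(img)).abs().max())  # ~0 when both steps use 255
```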
@@ -399,18 +400,36 @@ def normalize_weights(content_losses, style_losses):
i.strength = i.strength / max(i.target.size())


# Scale gradients in the backward pass
class ScaleGradients(torch.autograd.Function):
@staticmethod
def forward(self, input_tensor, strength):
self.strength = strength
return input_tensor

@staticmethod
def backward(self, grad_output):
grad_input = grad_output.clone()
grad_input = grad_input / (torch.norm(grad_input, keepdim=True) + 1e-8)
return grad_input * self.strength * self.strength, None


# Define an nn Module to compute content loss
class ContentLoss(nn.Module):

def __init__(self, strength):
def __init__(self, strength, normalize):
super(ContentLoss, self).__init__()
self.strength = strength
self.crit = nn.MSELoss()
self.mode = 'None'
self.normalize = normalize

def forward(self, input):
if self.mode == 'loss':
self.loss = self.crit(input, self.target) * self.strength
loss = self.crit(input, self.target)
if self.normalize:
loss = ScaleGradients.apply(loss, self.strength)
self.loss = loss * self.strength
elif self.mode == 'capture':
self.target = input.detach()
return input
@@ -427,14 +446,15 @@ def forward(self, input):
# Define an nn Module to compute style loss
class StyleLoss(nn.Module):

def __init__(self, strength):
def __init__(self, strength, normalize):
super(StyleLoss, self).__init__()
self.target = torch.Tensor()
self.strength = strength
self.gram = GramMatrix()
self.crit = nn.MSELoss()
self.mode = 'None'
self.blend_weight = None
self.normalize = normalize

def forward(self, input):
self.G = self.gram(input)
@@ -447,7 +467,10 @@ def forward(self, input):
else:
self.target = self.target.add(self.blend_weight, self.G.detach())
elif self.mode == 'loss':
self.loss = self.strength * self.crit(self.G, self.target)
loss = self.crit(self.G, self.target)
if self.normalize:
loss = ScaleGradients.apply(loss, self.strength)
self.loss = self.strength * loss
return input


@@ -465,4 +488,4 @@


if __name__ == "__main__":
main()
main()
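To see what the new `ScaleGradients` function does in isolation, here is a small self-contained sketch (not part of the PR) that applies the same kind of gradient-scaling `autograd.Function` to a toy tensor; the shapes and the strength value of 5.0 are arbitrary:

```python
import torch

class ScaleGradients(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor, strength):
        ctx.strength = strength
        return input_tensor  # identity in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        # Normalize the incoming gradient, then rescale it by strength^2
        grad_input = grad_input / (torch.norm(grad_input, keepdim=True) + 1e-8)
        return grad_input * ctx.strength * ctx.strength, None

x = torch.randn(1, 3, 8, 8, requires_grad=True)
y = ScaleGradients.apply(x, 5.0)   # forward pass returns x unchanged
(y ** 2).sum().backward()
print(torch.norm(x.grad))          # ~25.0: the rescaled gradient has norm strength**2
```

Because the forward pass is an identity, enabling `-normalize_gradients` only affects backpropagation: whatever gradient flows back through the point where the function is applied is renormalized and rescaled by the layer's strength squared before continuing down the network.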