Torch models to pytorch models, bug fix, CPU support, etc #18

Merged
merged 8 commits into from Mar 2, 2018

Contributor

@suquark suquark commented Feb 25, 2018

This PR:

  • Converts the Torch models to PyTorch models (listed as a TODO in the original code); converter.py shows how the conversion was done. The PyTorch model lives in the PhotoWCTModels submodule, which makes it easier to download from a server, as suggested in "Move trained models out of Google Drive" (#15).

  • Refactors the models into fewer, clearer classes. The layers are named according to the original paper.

  • Fixes a bug in Propagator: it failed to process images with alpha channels because it did not open them in RGB mode.

  • Adds CPU support for PhotoWCT. PhotoWCT can now run in CPU mode without calling .cuda(). This can make it ~10x slower (still not too slow), but it is friendlier to those without GPUs or with GPUs that have less memory, as raised in "Look at the memory consumption: should be able to process HD size picture" (#17).
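
The CPU fallback described in the last bullet can be sketched as below. This is a minimal illustration and not the code from this PR: the helper names `select_device` and `place_model` are hypothetical, and `Module.to(device)` is the idiom in current PyTorch releases rather than the explicit `.cuda()` calls the 2018 code base used.

```python
def select_device(prefer_cuda=True):
    """Return "cuda" when it is requested and usable, otherwise "cpu"."""
    if not prefer_cuda:
        return "cpu"
    try:
        # Imported lazily so the helper still answers "cpu" where PyTorch is absent.
        import torch
    except ImportError:
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


def place_model(model, prefer_cuda=True):
    """Move a torch.nn.Module (or anything with .to()) to the chosen device."""
    device = select_device(prefer_cuda)
    return model.to(device), device
```

With a helper like this, a demo script could call `place_model(p_wct, prefer_cuda=False)` to force CPU mode on machines without a usable GPU.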

I'm sorry that some of the code in photo_wct.py was reformatted by the (PEP8) code formatter, so not much of it is actually modified.

The current code is tested and works as well as before.


@xerebz xerebz left a comment

Great idea. Suggested changes to reduce confusion for end users.

demo.py Outdated
-parser.add_argument('--decoder3', default='./models/feature_invertor_conv3_1_mask.t7', help='Path to the decoder3')
-parser.add_argument('--decoder2', default='./models/feature_invertor_conv2_1_mask.t7', help='Path to the decoder2')
-parser.add_argument('--decoder1', default='./models/feature_invertor_conv1_1_mask.t7', help='Path to the decoder1')
+parser.add_argument('--model', default='./PhotoWCTModels/photo_wct.pth', help='Path to the PhotoWCT model')

For ease of use, change this to something like: help='Path to the PhotoWCT model. These are provided by the PhotoWCT submodule, please use git submodule update --init to pull.'
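
A minimal sketch of how the suggested help text might look in an argparse setup. The `--model` flag and its default path come from the diff above; the `build_parser` function, the `MODEL_HELP` constant, and the exact wording (including the assumption that the submodule meant is PhotoWCTModels) are illustrative only.

```python
import argparse

# Help text that tells the user where the weights come from.
MODEL_HELP = (
    "Path to the PhotoWCT model. The weights are provided by the "
    "PhotoWCTModels submodule; run `git submodule update --init` to pull them."
)


def build_parser():
    """Build a parser with only the --model option, for illustration."""
    parser = argparse.ArgumentParser(description="PhotoWCT demo (sketch)")
    parser.add_argument(
        "--model",
        default="./PhotoWCTModels/photo_wct.pth",
        help=MODEL_HELP,
    )
    return parser
```

Running the script with `--help` would then show the submodule reminder next to the `--model` option.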

demo.py Outdated
@@ -27,7 +20,8 @@
 args = parser.parse_args()

 # Load model
-p_wct = PhotoWCT(args)
+p_wct = PhotoWCT()
+p_wct.load_state_dict(torch.load(args.model))

For ease of use, use a try/except here with the exception message reminding the user to pull the submodule.
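
One way the suggested try/except could look, as a hedged sketch rather than the code that was eventually merged. The names `load_state`, `_torch_load`, and `SUBMODULE_HINT` are hypothetical; the loader is passed in as a parameter so the error handling can be exercised without PyTorch installed. `torch.load(path, map_location="cpu")` is an existing PyTorch call that lets a checkpoint saved on a GPU be read on a CPU-only machine.

```python
SUBMODULE_HINT = (
    "The model weights live in the PhotoWCTModels submodule; "
    "run `git submodule update --init` to pull them."
)


def _torch_load(path):
    """Default loader: read a checkpoint onto the CPU using PyTorch."""
    import torch  # imported lazily; only needed when this loader is used
    return torch.load(path, map_location="cpu")


def load_state(path, loader=_torch_load):
    """Load a state dict, turning a missing file into an actionable error."""
    try:
        return loader(path)
    except OSError as exc:  # FileNotFoundError is a subclass of OSError
        raise FileNotFoundError(
            f"Could not load model from '{path}'. {SUBMODULE_HINT}"
        ) from exc

# Intended use in demo.py (assumption):
#   p_wct.load_state_dict(load_state(args.model))
```

Re-raising with a hint keeps the original exception chained through `from exc`, so the underlying cause stays visible in the traceback.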

Contributor Author

Thanks. Updated.

Contributor Author

suquark commented Feb 28, 2018

Hi, I think this PR is heavy, and it would be better to merge it before more PRs come in that may cause unexpected conflicts. @mingyuliutw @Yijunmaverick

@mingyuliutw
Collaborator

Merge #18

@mingyuliutw mingyuliutw closed this Mar 2, 2018
@mingyuliutw mingyuliutw merged commit 7508715 into NVIDIA:master Mar 2, 2018
@moeedkundi

Processing speed and GPU memory usage are a bit worse than in the original. Was this PR supposed to improve GPU memory usage and overall run time too?

4 participants