Torch models to PyTorch models, bug fix, CPU support, etc. #18
…layers according to the original paper.
Great idea. Suggested changes to reduce confusion for end users.
demo.py
```python
parser.add_argument('--decoder3', default='./models/feature_invertor_conv3_1_mask.t7', help='Path to the decoder3')
parser.add_argument('--decoder2', default='./models/feature_invertor_conv2_1_mask.t7', help='Path to the decoder2')
parser.add_argument('--decoder1', default='./models/feature_invertor_conv1_1_mask.t7', help='Path to the decoder1')
parser.add_argument('--model', default='./PhotoWCTModels/photo_wct.pth', help='Path to the PhotoWCT model')
```
For ease of use, change this to something like: help='Path to the PhotoWCT model. These are provided by the PhotoWCT submodule, please use git submodule update --init to pull.'
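A minimal sketch of the suggested help text, assuming `demo.py` builds its options with `argparse` as in the excerpt above. Only the `--model` option mirrors the diff; the parser description is a placeholder, not the project's.

```python
import argparse

# Placeholder description; only the --model option mirrors the diff under review.
parser = argparse.ArgumentParser(description='PhotoWCT demo (sketch)')
parser.add_argument(
    '--model',
    default='./PhotoWCTModels/photo_wct.pth',
    # The hint tells users where the weights come from and how to fetch them.
    help='Path to the PhotoWCT model. These are provided by the PhotoWCT '
         'submodule, please use git submodule update --init to pull.',
)
```

With this in place, `python demo.py --help` prints the submodule hint next to `--model`.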
demo.py
```
@@ -27,7 +20,8 @@
args = parser.parse_args()

# Load model
p_wct = PhotoWCT(args)
p_wct = PhotoWCT()
p_wct.load_state_dict(torch.load(args.model))
```
For ease of use, use a try/except here with the exception message reminding the user to pull the submodule.
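One way to follow this suggestion, sketched under the assumption that an uninitialized submodule shows up as a missing model file (cloning without initializing submodules leaves the submodule directory empty). `load_checkpoint` and `SUBMODULE_HINT` are hypothetical names; in `demo.py` the loader would be `torch.load`, which raises `FileNotFoundError` for a nonexistent path.

```python
SUBMODULE_HINT = (
    "Could not find the PhotoWCT model at {path}. The models are provided by "
    "the PhotoWCTModels submodule; run `git submodule update --init` to pull them."
)


def load_checkpoint(path, loader):
    """Call loader(path), turning a missing file into an actionable error.

    `loader` is any callable that takes a path, e.g. torch.load in demo.py.
    """
    try:
        return loader(path)
    except FileNotFoundError as err:
        # A missing file most likely means the submodule was never pulled.
        raise SystemExit(SUBMODULE_HINT.format(path=path)) from err
```

In `demo.py` the call would read `p_wct.load_state_dict(load_checkpoint(args.model, torch.load))`. Raising `SystemExit` prints the message and exits without a traceback, which suits a command-line demo; re-raising a `FileNotFoundError` carrying the same hint would also work.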
Thanks. Updated.
Hi, I think this PR is large, and it's better to merge it before more PRs come in that may cause unexpected conflicts. @mingyuliutw @Yijunmaverick
Merge #18
Processing speed and GPU memory usage are a bit worse than with the original. Was this PR also supposed to improve GPU memory usage and overall running time?
This PR:
- Converts the Torch models to PyTorch models (listed in the TODOs of the original code); `converter.py` shows how it was done. The PyTorch models live in the submodule `PhotoWCTModels`, which makes them easier to download from a server, as "Move trained models out of Google Drive" #15 suggested. The models are refactored into fewer, clearer classes, and the layers are named according to the original paper.
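The existing model file names (`feature_invertor_conv3_1_mask.t7` and so on) already follow VGG-style layer naming. As a hypothetical illustration of how a converter could assign such names while walking a flat Lua Torch `nn.Sequential` (an assumption, not the PR's `converter.py`), `vgg_layer_names` numbers each conv/ReLU within its block and starts a new block after every pooling layer:

```python
def vgg_layer_names(layer_types):
    """Name a flat list of layer kinds VGG-style: conv1_1, relu1_1, ..., pool1, conv2_1, ...

    `layer_types` holds 'conv', 'relu' or 'pool' strings, e.g. as read
    (hypothetically) from a Lua Torch nn.Sequential encoder.
    """
    names = []
    block, index = 1, 0
    for kind in layer_types:
        if kind == 'conv':
            index += 1  # next conv within the current block
            names.append(f'conv{block}_{index}')
        elif kind == 'relu':
            names.append(f'relu{block}_{index}')  # shares its conv's index
        elif kind == 'pool':
            names.append(f'pool{block}')
            block, index = block + 1, 0  # pooling closes the block
        else:
            raise ValueError(f'unsupported layer kind: {kind!r}')
    return names
```

Using these strings as the attribute names of the corresponding `nn.Conv2d` modules would make the `state_dict` keys read `conv1_1.weight`, `conv1_1.bias`, and so on.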
- Fixes a bug in `Propagator`: it failed to process images with an alpha channel because it did not open them in RGB mode.
- Adds CPU support for PhotoWCT: it can now run in CPU mode without calling `.cuda()`. This can make it ~10x slower (still acceptable), but it is friendlier for users without GPUs or with low-memory GPUs, as discussed in "Look at the memory consumption: should be able to process HD size picture" #17.

I'm sorry that some code in `photo_wct.py` was changed by the (PEP8) code formatter; not much of it is actually modified. The current code is tested and works as well as before.
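The `Propagator` alpha-channel fix comes down to converting every loaded image to RGB. A minimal sketch with Pillow, assuming images are read via `PIL.Image.open`; the helper name `load_rgb` is hypothetical.

```python
import io

from PIL import Image


def load_rgb(fp):
    """Open an image (path or binary file object) and return it in RGB mode.

    RGBA and grayscale inputs come back as 3-channel 'RGB'; for RGBA the
    alpha channel is dropped, not blended onto a background.
    """
    with Image.open(fp) as img:
        # convert() loads the pixel data, so the result outlives the open file.
        return img.convert('RGB')


if __name__ == '__main__':
    # Round-trip a semi-transparent RGBA PNG through an in-memory buffer.
    buf = io.BytesIO()
    Image.new('RGBA', (4, 3), (255, 0, 0, 128)).save(buf, format='PNG')
    buf.seek(0)
    rgb = load_rgb(buf)
    print(rgb.mode, rgb.size, rgb.getpixel((0, 0)))  # RGB (4, 3) (255, 0, 0)
```

Dropping alpha is the simplest behavior; compositing onto a white background first would be an alternative if transparent regions should not keep their hidden colors.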