Improve multi scale loss #27
In other words, what I am suggesting is to modify the forward() function of FlowNetS as follows:
Hmm, that could be a good thing for sparse targets, but I don't like the idea of treating FlowNetSUp as a different network. You should be able to use a pretrained FlowNetS as a FlowNetSUp without having to convert it. I get the problem and its solution, but learning full-resolution flow should be handled in the training process rather than in the network definition. I'll see what can be done nicely :)
Sure. It should be possible to put the upsampling outside the network definition, just before calling the loss function.
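A minimal PyTorch sketch of that approach: the network is left untouched, and only the finest prediction is bilinearly upsampled to the target size right before the end-point error (EPE) is computed. The helper names (`epe`, `multiscale_epe_upsampled`), the per-scale weights and the use of average pooling for the coarser targets are illustrative assumptions, not the repository's code. Flow values are not rescaled, assuming the predictions are already expressed in full-resolution units.

```python
import torch
import torch.nn.functional as F


def epe(pred, target):
    # Mean end-point error: Euclidean norm of the 2-channel flow difference.
    return torch.norm(target - pred, p=2, dim=1).mean()


def multiscale_epe_upsampled(preds, target, weights=None):
    """Hypothetical multi-scale EPE whose finest term is computed at native resolution.

    preds:  tuple of (B, 2, h_i, w_i) flow predictions, finest first.
    target: dense (B, 2, H, W) ground-truth flow.
    """
    if weights is None:
        weights = [0.005, 0.01, 0.02, 0.08, 0.32]  # illustrative per-scale weights
    H, W = target.shape[-2:]
    loss = 0.0
    for i, (pred, w) in enumerate(zip(preds, weights)):
        if i == 0:
            # Upsample the prediction instead of shrinking the target.
            pred = F.interpolate(pred, size=(H, W), mode='bilinear', align_corners=False)
            scaled_target = target
        else:
            # Coarser scales keep the usual target downsampling (area averaging).
            scaled_target = F.adaptive_avg_pool2d(target, pred.shape[-2:])
        loss = loss + w * epe(pred, scaled_target)
    return loss
```

Because the upsampling lives in the loss, a pretrained FlowNetS can be trained this way without any change to its definition or weights.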
Fixed
The current multi-scale loss never computes the loss at the native resolution. This is because the highest-resolution output of FlowNetS (flow2) is smaller than the target or input: it is at one quarter of the input resolution.
The problem is more severe for sparse targets (e.g. KITTI), since we don't really resample the sparse target accurately but use max pooling instead. I believe this is because PyTorch may not have nearest-neighbor resizing that supports flexible output sizes. Even if we had nearest-neighbor resizing, it would be inaccurate too.
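To make the sparse case concrete, here is a small PyTorch sketch. It assumes, hypothetically, that invalid KITTI pixels are stored as zero flow, and uses a naive per-channel `F.adaptive_max_pool2d` as a stand-in for the max-pooling workaround (the repository's exact pooling may differ). Because each channel is pooled independently, a 2x2 block whose only valid vector is (-3, 2) collapses to (0, 2), a vector that appears nowhere in the ground truth. Comparing a full-resolution prediction against the full-resolution target with a validity mask avoids resizing the target at all.

```python
import torch
import torch.nn.functional as F


def naive_sparse_max_pool(target, size):
    # Per-channel max pooling: the x and y components of the result can come
    # from different pixels, and negative components lose to invalid zeros.
    return F.adaptive_max_pool2d(target, size)


def sparse_epe(pred, target):
    # Assumption: invalid ground-truth pixels hold exactly zero flow.
    valid = (target[:, 0] != 0) | (target[:, 1] != 0)
    err = torch.norm(target - pred, p=2, dim=1)  # (B, H, W) per-pixel EPE
    return err[valid].mean()


# A 2x2 sparse target with a single valid vector (-3, 2) in the top-left pixel.
target = torch.zeros(1, 2, 2, 2)
target[0, :, 0, 0] = torch.tensor([-3.0, 2.0])

pooled = naive_sparse_max_pool(target, (1, 1))
print(pooled[0, :, 0, 0].tolist())  # [0.0, 2.0]: not a flow vector from the target

# Full-resolution comparison: only the valid pixel contributes.
pred = torch.zeros(1, 2, 2, 2)
print(sparse_epe(pred, target).item())  # sqrt(13), about 3.6056
```

For reference, current PyTorch's `F.interpolate` does accept `mode='nearest'` with an arbitrary `size`, but it still keeps a single source pixel per output cell, which may well be an invalid one.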
A quick fix would be to bilinearly upsample the output of FlowNetS. The error at the finest scale would then be computed without any resizing of the target.
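The size gap and the proposed upsampling can be sketched in a few lines of PyTorch. The strides follow the FlowNetS architecture (flow2 at 1/4 of the input resolution, flow6 at 1/64); the 384x512 input is the FlyingChairs image size, and the random tensor merely stands in for a real network output.

```python
import torch
import torch.nn.functional as F

H, W = 384, 512  # input resolution (FlyingChairs images)
# Output stride of each FlowNetS prediction relative to the input.
strides = {'flow2': 4, 'flow3': 8, 'flow4': 16, 'flow5': 32, 'flow6': 64}
shapes = {name: (H // s, W // s) for name, s in strides.items()}
print(shapes['flow2'], shapes['flow6'])  # (96, 128) (6, 8)

# Stand-in for the finest prediction of the network.
flow2 = torch.randn(1, 2, *shapes['flow2'])
# Bilinear upsampling to the target resolution; the target itself is untouched.
flow_full = F.interpolate(flow2, size=(H, W), mode='bilinear', align_corners=False)
print(tuple(flow_full.shape))  # (1, 2, 384, 512)
```

This direct comparison assumes the network predicts flow in full-resolution units; a model predicting flow in units of its own output grid would also need its upsampled values multiplied by the stride.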
The following derivative of FlowNetS, which I call FlowNetSUp, does exactly that. There is no real need to define a new class; this could be incorporated into FlowNetS itself.
If this line of reasoning is correct, the change should improve training and hence accuracy.
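File: FlowNetSUp.py

import torch
import torch.nn.functional as F
from .FlowNetS import FlowNetS

__all__ = [
    'FlowNetSUp', 'flownets_up', 'flownets_up_bn'
]


class FlowNetSUp(FlowNetS):
    # The class body did not survive in this copy of the issue. This forward()
    # reconstructs the idea described above: bilinearly upsample the finest
    # FlowNetS output to the input resolution. It assumes FlowNetS.forward
    # returns a tuple of flows, finest (flow2) first, in training mode and a
    # single flow2 tensor otherwise. Flow magnitudes are not rescaled.
    def forward(self, x):
        out = super(FlowNetSUp, self).forward(x)
        size = x.shape[-2:]
        if isinstance(out, (tuple, list)):
            flow2_up = F.interpolate(out[0], size=size, mode='bilinear', align_corners=False)
            return (flow2_up,) + tuple(out[1:])
        return F.interpolate(out, size=size, mode='bilinear', align_corners=False)


def flownets_up(path=None):
    model = FlowNetSUp(batchNorm=False)
    if path is not None:
        data = torch.load(path)
        # Checkpoints may wrap the weights in a 'state_dict' entry.
        if 'state_dict' in data.keys():
            model.load_state_dict(data['state_dict'])
        else:
            model.load_state_dict(data)
    return model


def flownets_up_bn(path=None):
    model = FlowNetSUp(batchNorm=True)
    if path is not None:
        data = torch.load(path)
        # Checkpoints may wrap the weights in a 'state_dict' entry.
        if 'state_dict' in data.keys():
            model.load_state_dict(data['state_dict'])
        else:
            model.load_state_dict(data)
    return model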