models fail to decompress #25

Closed
AndrewZhao opened this issue Jan 29, 2018 · 22 comments

Comments

@AndrewZhao

Hello, I downloaded flownets_bn_EPE2.459.pth.tar and flownets_EPE1.951.pth.tar, but when I decompress them, the files are broken. Can you share the two files again?
Thanks very much.

@ClementPinard
Owner

You don't have to decompress them; you just load them with the --pretrained argument in the main file! The .pth.tar extension is mostly arbitrary.

@AndrewZhao
Author

Got it. Thanks very much.

@ghost

ghost commented Oct 31, 2018

How and where do I load the .pth.tar file?

@ClementPinard
Owner

ClementPinard commented Oct 31, 2018

In the command line options there is a --pretrained option where you can specify the file path:

python3 train.py --pretrained path_to_.pth.tar

Other than that, if you want to use it in another project, e.g. in a notebook, you have to call the constructor function for your network (say flownets):

import torch
from models import flownets

weights = torch.load('path_to_.pth.tar')  # checkpoint dict containing 'state_dict'
my_model = flownets(data=weights)
dummy_img = torch.randn(1, 6, 100, 100)  # batch of 1; two RGB frames concatenated -> 6 channels
my_flow = my_model(dummy_img)

You can see the different constructors you can call at the end of models/FlowNetS.py and models/FlowNetC.py:
https://github.com/ClementPinard/FlowNetPytorch/blob/master/models/FlowNetS.py#L95
https://github.com/ClementPinard/FlowNetPytorch/blob/master/models/FlowNetC.py#L112

@ghost

ghost commented Oct 31, 2018

Thank you for the reply. I will try it and get back to you soon.

@ghost

ghost commented Nov 5, 2018

This is the code I have written:

import torch
from models import flownets
import cv2
import numpy

filename1 = 'img_1.jpg'
filename2 = 'img_2.jpg'

w = 512.
h = 384.

oriimg1 = cv2.imread(filename1, cv2.IMREAD_COLOR)
oriimg2 = cv2.imread(filename2, cv2.IMREAD_COLOR)

height, width, depth = oriimg1.shape
imgScale1 = w/width
imgScale0 = h/height
newX, newY = oriimg1.shape[1]*imgScale1, oriimg1.shape[0]*imgScale0
newimg = cv2.resize(oriimg1, (int(newX), int(newY)))
cv2.imshow("Show by CV2", newimg)
#cv2.waitKey(0)
cv2.imwrite("resizeimg1.jpg", newimg)

height, width, depth = oriimg2.shape
imgScale1 = w/width
imgScale0 = h/height
newX, newY = oriimg2.shape[1]*imgScale1, oriimg2.shape[0]*imgScale0
newimg = cv2.resize(oriimg2, (int(newX), int(newY)))
cv2.imshow("Show by CV2", newimg)
#cv2.waitKey(0)
cv2.imwrite("resizeimg2.jpg", newimg)

c1 = cv2.imread('resizeimg1.jpg', cv2.IMREAD_COLOR)
c2 = cv2.imread('resizeimg2.jpg', cv2.IMREAD_COLOR)

c1 = torch.from_numpy(numpy.array(c1)).double()
c2 = torch.from_numpy(numpy.array(c2)).double()

input_image = torch.cat((c1, c2), 2)
input_image = numpy.expand_dims(input_image, axis=0)
input_image = torch.from_numpy(numpy.array(input_image))
print(input_image.shape)
input_image = input_image.transpose(2, 3).transpose(1, 2)
print(input_image.shape)
input_image = input_image.type('torch.FloatTensor')

weights = torch.load('/home/prateek/Prateek/FlowNetPytorch/flownets_EPE1.951.pth.tar')
my_model = flownets(data=weights)
my_flow = my_model(input_image)
print(my_flow)
my_flow = torch.Tensor.numpy(my_flow)
cv2.imshow('opticalflow.jpg', my_flow)
cv2.waitKey(0)

I get the error:
my_flow = torch.Tensor.numpy(my_flow)
TypeError: descriptor 'numpy' requires a 'torch._C._TensorBase' object but received a 'tuple'

What should I do to get rid of this?
Please help.
Thanks.

@ClementPinard
Owner

You need to put the model in eval mode. That way, it won't output flow at multiple scales, only the highest resolution one:
my_model = flownets(data=weights).eval()

Also, your code is very convoluted; here is a simplified version:

import torch
from models import flownets
import cv2
import numpy

filename1 = 'img_1.jpg'
filename2 = 'img_2.jpg'

w = 512
h = 384

oriimg1 = cv2.imread(filename1, cv2.IMREAD_COLOR)
oriimg2 = cv2.imread(filename2, cv2.IMREAD_COLOR)

# cv2.resize takes an integer (width, height) tuple
img1_scaled = cv2.resize(oriimg1, (w, h))
img2_scaled = cv2.resize(oriimg2, (w, h))

# HWC uint8 -> CHW float, normalized to [-0.5, 0.5]
img1_tensor = torch.from_numpy(img1_scaled).float().permute(2, 0, 1)
img1_tensor = img1_tensor/255 - 0.5

img2_tensor = torch.from_numpy(img2_scaled).float().permute(2, 0, 1)
img2_tensor = img2_tensor/255 - 0.5

# concatenate the two frames along channels and add a batch dimension
input_image = torch.cat((img1_tensor, img2_tensor)).unsqueeze(0)

weights = torch.load('/home/prateek/Prateek/FlowNetPytorch/flownets_EPE1.951.pth.tar')
my_model = flownets(data=weights).eval()
my_flow = my_model(input_image)
print(my_flow)
my_flow = my_flow[0].detach().numpy()  # first (and only) element of the batch: (2, H/4, W/4)
cv2.imshow('opticalflowu.jpg', my_flow[0])
cv2.imshow('opticalflowv.jpg', my_flow[1])
cv2.waitKey(0)

Note that no CUDA is used for the moment, so this will be a bit slow. You can move your input and the model to the GPU with the .to(torch.device('cuda')) function, and then bring the result back to the CPU with the .cpu() function before converting to numpy.
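For example (a minimal sketch, assuming a CUDA-capable GPU and the my_model / input_image variables from the snippet above):

import torch

# Move model and input to the GPU, run inference, bring the result back.
device = torch.device('cuda')
my_model = my_model.to(device)
input_image = input_image.to(device)
with torch.no_grad():
    my_flow = my_model(input_image)
my_flow = my_flow.cpu()[0].numpy()  # back to CPU before converting to numpy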

Also, don't forget that optical flow is 2D, so OpenCV won't be able to show it directly, since it can only show grayscale (1 channel) or color (3 channels) images, hence the two imshow calls at the end. If you want a unified flow map window, you can have a look at the function that I designed: it's essentially a YUV->RGB converter that maps flow to U and V and sets Y to 1.
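Roughly, the idea looks like this (a hypothetical flow_to_rgb helper, not the exact function from the repository; normalizing by the flow's max magnitude is my assumption):

import numpy as np
import cv2

def flow_to_rgb(flow):
    """Hypothetical helper: map a (2, H, W) flow to RGB via YUV, with Y fixed to 1."""
    u, v = flow[0], flow[1]
    scale = max(np.abs(flow).max(), 1e-8)   # assumed normalization
    yuv = np.stack([np.ones_like(u),        # Y = 1 everywhere
                    0.5 * u / scale,        # flow u -> U chroma
                    0.5 * v / scale],       # flow v -> V chroma
                   axis=-1).astype(np.float32)
    return np.clip(cv2.cvtColor(yuv, cv2.COLOR_YUV2RGB), 0, 1)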

@ghost

ghost commented Nov 10, 2018

Sir,
the flow we get is of shape (2, 128, 96). How do I upsample it to the input image size?
I have gone through the code in main.py, in which I found this:

# compute output
output = model(input)
if args.sparse:
    # Since Target pooling is not very precise when sparse,
    # take the highest resolution prediction and upsample it instead of downsampling target
    h, w = target.size()[-2:]
    output = [F.interpolate(output[0], (h,w)), *output[1:]]

What does this code do? Could you please elaborate on it?

One last doubt: could you please explain why we only consider my_flow[0] and not the whole tensor? Sorry for the inconvenience caused.

@ClementPinard
Owner

ClementPinard commented Nov 10, 2018

Some insight for you to read: #23, #27.
The main idea behind this is that downsampling a sparse flow map is not easy, because you can't interpolate. So we upsample the flow map to compare it to the non-downsampled target. This is not done when the target is dense, in order to stay consistent with the original paper.

As for my_flow[0], the main idea is that in PyTorch everything is batched. Even if you have only one element, you must construct a batch of one element. As a consequence, when taking the output of your network, if you want the flow map of your only element, you must take the first element of the batch.
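As an aside, a minimal sketch of the upsampling step for your case (assuming the raw (1, 2, 96, 128) network output and your 384×512 input; the ×4 value rescaling is my addition, valid for flow expressed in pixel units):

import torch.nn.functional as F

# Bilinearly resize the prediction back to the input resolution.
h, w = 384, 512
flow_full = F.interpolate(my_flow, size=(h, w), mode='bilinear', align_corners=False)
# Spatially upsampling a flow field by a factor s also stretches the
# displacements, so pixel-unit values should be multiplied by s (here 4).
flow_full *= 4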

@ghost

ghost commented Nov 10, 2018

Thank you, sir, for these insights.
I will get back to you again if I have any doubts.

@ghost

ghost commented Nov 14, 2018

Respected Sir,
I have tried to verify the output flow by computing the warped image with the code below, but failed. Could you please tell me what is wrong in the code?

import torch
from models import flownets
import cv2
import numpy
import torch.nn.functional as F

filename1 = 'img_1.jpg'
filename2 = 'img_2.jpg'

w = 512.
h = 384.

oriimg1 = cv2.imread(filename1, cv2.IMREAD_COLOR)
print(oriimg1.shape)
oriimg2 = cv2.imread(filename2, cv2.IMREAD_COLOR)

img1_scaled = cv2.resize(oriimg1, (int(w), int(h)))
img2_scaled = cv2.resize(oriimg2, (int(w), int(h)))
print(img1_scaled.shape)

img1_tensor = torch.from_numpy(img1_scaled).float().permute(2, 0, 1)
img1_tensor = (img1_tensor/255) - 0.5

img2_tensor = torch.from_numpy(img2_scaled).float().permute(2, 0, 1)
img2_tensor = (img2_tensor/255) - 0.5

input_image = torch.cat((img1_tensor, img2_tensor)).unsqueeze(0)

weights = torch.load('/home/prateek/Prateek/FlowNetPytorch/flownets_EPE1.951.pth.tar')

my_model = flownets(data=weights).eval()
my_flow = my_model(input_image)

my_flow = F.interpolate(my_flow, size=(int(h), int(w)), mode='bilinear').permute(0, 2, 3, 1)

img1_tensor = img1_tensor.unsqueeze(0)

verify_output = F.grid_sample(img1_tensor, my_flow, mode='bilinear')

verify_output = verify_output[0].detach().numpy().transpose(1, 2, 0)  # CHW -> HWC

cv2.imshow('aaa', verify_output)
cv2.waitKey(0)

@ClementPinard
Owner

ClementPinard commented Nov 14, 2018

Can you paste the stack trace? What fails? Can you visualize the flow?

Also, remember that grid_sample takes explicit coordinates, not relative ones: you have to add the identity grid to the flow, otherwise you would sample the point at (0, 0) with a null flow instead of sampling it at (i, j). Besides, grid_sample wants values between -1 and 1 relative to the image boundaries; everything outside will give you gray (the color (0, 0, 0)).

So before sampling your image from the flow, you need to:

  1. add the identity grid. You can construct it with torch.arange and a little bit of broadcasting, see here for an example
  2. divide and shift everything so that the point (0, 0) becomes (-1, -1) and the point (H-1, W-1) becomes (1, 1), see here for another example

I recommend you read the documentation on grid_sample here.

Finally, remember that grid_sample is an inverse warp. That means that with an optical flow from Img1 to Img2, you can reconstruct Img1 from Img2, and not the other way around: you want to grid_sample from img2_tensor, not from img1_tensor. A sketch combining both steps is below.
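Putting the two steps together (a hypothetical inverse_warp helper, assuming flow in pixel units with channel order (u, v)):

import torch
import torch.nn.functional as F

def inverse_warp(img, flow):
    """Hypothetical helper: reconstruct the reference frame by sampling img with flow.

    img:  (B, 3, H, W) tensor to sample from (here, img2).
    flow: (B, 2, H, W) tensor of pixel displacements, channels (u, v).
    """
    b, _, h, w = img.shape
    # 1. identity grid of pixel coordinates, via torch.arange + broadcasting
    xx = torch.arange(w).view(1, 1, 1, w).expand(b, 1, h, w).float()
    yy = torch.arange(h).view(1, 1, h, 1).expand(b, 1, h, w).float()
    grid = torch.cat((xx, yy), dim=1) + flow  # absolute sampling coordinates
    # 2. normalize so (0, 0) -> (-1, -1) and (W-1, H-1) -> (1, 1)
    grid[:, 0] = 2 * grid[:, 0] / (w - 1) - 1
    grid[:, 1] = 2 * grid[:, 1] / (h - 1) - 1
    return F.grid_sample(img, grid.permute(0, 2, 3, 1), mode='bilinear')

With an img1→img2 flow, verify_output = inverse_warp(img2_tensor.unsqueeze(0), flow) should then resemble img1.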

@ghost

ghost commented Nov 14, 2018

There is no problem with the syntax as such, but I get this output in the terminal:
(3968, 2976, 3)
(384, 512, 3)
/usr/local/lib/python2.7/dist-packages/torch/nn/functional.py:1961: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
"See the documentation of nn.Upsample for details.".format(mode))

[screenshot from 2018-11-14 17-02-24]

I get the above image as the output.

@ghost

ghost commented Nov 15, 2018

Respected Sir,
Thank you for these insights. Actually, my aim is to get frame 2 from frame 1. Thank you for making me understand that grid_sample gives us the inverse warp.
I want to get future frames from the present frame.
I have tried this code:

import torch
from models import flownets
import cv2
import numpy
import torch.nn.functional as F

filename1 = 'img_1.jpg'
filename2 = 'img_2.jpg'

w = 512.
h = 384.

oriimg1 = cv2.imread(filename1, cv2.IMREAD_COLOR)
oriimg2 = cv2.imread(filename2, cv2.IMREAD_COLOR)

img1_scaled = cv2.resize(oriimg1, (int(w), int(h)))
img2_scaled = cv2.resize(oriimg2, (int(w), int(h)))

img1_tensor = torch.from_numpy(img1_scaled).float().permute(2, 0, 1)
img1_tensor = (img1_tensor/255) - 0.5

img2_tensor = torch.from_numpy(img2_scaled).float().permute(2, 0, 1)
img2_tensor = (img2_tensor/255) - 0.5

input_image = torch.cat((img1_tensor, img2_tensor)).unsqueeze(0)

weights = torch.load('/home/prateek/Prateek/FlowNetPytorch/flownets_EPE1.951.pth.tar')

my_model = flownets(data=weights).eval()
my_flow = my_model(input_image)

my_flow = F.interpolate(my_flow, size=(int(h), int(w)), mode='bilinear', align_corners=True).permute(0, 2, 3, 1)
my_flow = my_flow[0].detach().numpy()

mapx = numpy.zeros((384, 512, 1))
mapy = numpy.zeros((384, 512, 1))

height = my_flow.shape[0]
width = my_flow.shape[1]
R2 = numpy.dstack(numpy.meshgrid(numpy.arange(width), numpy.arange(height)))
pixel_map = R2 + my_flow

cv2.convertMaps(pixel_map[:,:,0], None, cv2.CV_32FC1, mapx, nninterpolation=False)

new_frame = cv2.remap(img1_scaled, pixel_map[:,:,0], pixel_map[:,:,1], interpolation=cv2.INTER_LINEAR)

But I get this error:
Traceback (most recent call last):
File "run_flownetS.py", line 53, in
cv2.convertMaps(pixel_map[:,:,0],None, cv2.CV_32FC1,mapx, nninterpolation=False)
cv2.error: OpenCV(3.4.3) /io/opencv/modules/imgproc/src/imgwarp.cpp:1845: error: (-215:Assertion failed) (m1type == CV_16SC2 && (nninterpolate || m2type == CV_16UC1 || m2type == CV_16SC1)) || (m2type == CV_16SC2 && (nninterpolate || m1type == CV_16UC1 || m1type == CV_16SC1)) || (m1type == CV_32FC1 && m2type == CV_32FC1) || (m1type == CV_32FC2 && m2->empty()) in function 'convertMaps'

Could you please help me resolve this error?
And could you please suggest a function to get the warped image using optical flow?

@ClementPinard
Owner

I don't know about your problem, which seems to be related to cv2 only.
However, remap is actually the same thing as an inverse warp.

Sorry, there is no trivial way to make a direct warp: you will likely end up with ambiguities in occlusion zones. How do you decide which point comes in front and overwrites the other?
It is important to note that while inverse warping is a function, in which every pixel of the warped result gets a value, the same is not true of direct warping.

Best advice is probably to compute the inverse flow (img2 to img1) if you want to warp Img1 into Img2.

@ghost

ghost commented Nov 15, 2018

Respected Sir,
Thank you for this suggestion. After getting the optical flow from img2 to img1, doing an inverse warp would definitely work for me. Could you please help me get rid of the error in
#25 (comment)?

Which function is better to use for the inverse warp, grid_sample or remap?
This would really help me a lot. I kindly request you to guide me through this issue.

Looking forward to your reply, Sir.
Thank you.

@ClementPinard
Owner

The two functions are the same. grid_sample can be used with CUDA, though.

@ghost

ghost commented Nov 16, 2018

Respected Sir,
I have tried to use grid_sample but failed. Could you please help me by providing the exact code for warping the image, usable with my code? As I am a beginner, it is taking quite some time.

@ghost

ghost commented Nov 16, 2018

Sir, could you please explain what proj_c2p_rot and proj_c2p_tr are in:

https://github.com/ClementPinard/SfmLearner-Pytorch/blob/61fc23eed1f75ea856e01369c9212a1315196567/inverse_warp.py#L43

And could you please elaborate on what intrinsic_inv is?

@ClementPinard
Owner

ClementPinard commented Nov 16, 2018

These are not something you should worry about; that's the part of the inverse warp function that computes the pixel coordinates from depth + a 6-degrees-of-freedom pose. For your use case, you already have the coordinates.

All you need to do is explained in this comment (above):

so before sampling your image from the flow, you need to:

  1. add the identity grid. You can construct it with torch.arange and a little bit of broadcasting, see here for an example

  2. divide and shift everything so that the point (0, 0) becomes (-1, -1) and the point (H-1, W-1) becomes (1, 1), see here for another example

You already had this workflow with OpenCV when you tried numpy.meshgrid and cv2.remap; it is exactly the same, just with different functions, and grid_sample takes normalized coordinates instead of pixel coordinates, as in the sketch below.
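For instance, a minimal remap sketch of the same workflow (reusing my_flow and img1_scaled from your script, and assuming my_flow is an (H, W, 2) pixel-unit flow from img2 to img1, per the advice above):

import numpy
import cv2

h, w = my_flow.shape[:2]
# identity grid of pixel coordinates: the remap counterpart of step 1
identity = numpy.dstack(numpy.meshgrid(numpy.arange(w), numpy.arange(h))).astype(numpy.float32)
pixel_map = (identity + my_flow).astype(numpy.float32)
# remap keeps pixel coordinates, so no normalization step is needed, unlike grid_sample
warped = cv2.remap(img1_scaled, pixel_map[:, :, 0], pixel_map[:, :, 1],
                   interpolation=cv2.INTER_LINEAR)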

Good luck !

@ghost

ghost commented Nov 23, 2018

Sir,
how would we get img1 from img2 using the flow obtained?

Sir, I have checked the flow values. Why are the flow values really small?
My idea of flow is this:
let us assume that we found the flow from img1 to img2.
Suppose we pick some random flow value located at [x, y]; let the flow value at this point be [u, v].
Suppose we create an empty numpy array of the same size as the input images.
Let a = x + u and b = y + v.
Now, if we take the pixel located at [a, b] in img2 and place it in the empty array at [x, y], we get back img1.

Is this understanding right or wrong? Kindly correct me if I am wrong.

@ClementPinard
Owner

The values are usually divided by 20; this is a legacy from the original authors' code.
And yes, your understanding of inverse warp is correct: you just have to do that for every pixel (x, y) of img1, reconstructing img1 with colors sampled from img2.
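For what it's worth, here is a literal nearest-neighbour transcription of that per-pixel idea (deliberately slow; grid_sample/remap are the vectorized, interpolated equivalents; flow and img2 are assumed to be (H, W, 2) and (H, W, 3) numpy arrays, with flow in pixel units after undoing the division by 20):

import numpy

h, w = img2.shape[:2]
reconstruction = numpy.zeros_like(img2)
for y in range(h):
    for x in range(w):
        u, v = flow[y, x]                          # displacement at (x, y)
        a = int(round(min(max(x + u, 0), w - 1)))  # clamp to image bounds
        b = int(round(min(max(y + v, 0), h - 1)))
        reconstruction[y, x] = img2[b, a]          # sample img2 to rebuild img1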
