Hi @amdegroot, I was trying to get the demo running and I'm having a problem when calling transform(img) of the BaseTransform class.
When running python test.py, the output is the following:
Finished loading model!
Testing image 1/4952....
Traceback (most recent call last):
  File "test.py", line 84, in <module>
    thresh=args.visual_threshold)
  File "test.py", line 39, in test_net
    x = Variable(transform(img).unsqueeze(0))
  File "/home/arian/Documents/proyecto-integrador/models/ssd/ssd-pytorch/data/data_augment.py", line 119, in __call__
    return torch.Tensor(img)
RuntimeError: tried to construct a tensor from a nested float sequence, but found an item of type numpy.float32 at index (0, 0, 0)
This happens in the demo notebook and in the test.py file.
Do you have any idea why this could be happening?
Thanks,
Arian.
OK, sorry for opening an issue before trying everything I could to fix it.
I don't use torch, so I had no idea what to do, but here's how I fixed it, just in case.
In data_augment.py I had to change return torch.Tensor(img) to return torch.FloatTensor(img).
Then an error appeared saying it expected a CPU tensor but got a CUDA tensor instead, so I had to add .cuda() to the input variables (I don't know if there is another way to do it; I don't really know how to use PyTorch yet). For example, I changed x = Variable(transform(img).unsqueeze(0)) to x = Variable(transform(img).unsqueeze(0).cuda()).
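The .cuda() workaround ties test.py to a GPU. A device-agnostic alternative is sketched below, assuming PyTorch 0.4 or later, where the Variable wrapper is no longer needed and both tensors and modules have a .to(device) method. The small nn.Conv2d is a hypothetical stand-in for the SSD network and the 300x300 input is an illustrative assumption; neither is taken from this repo.

```python
import numpy as np
import torch
import torch.nn as nn

# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for the SSD model; any nn.Module is moved the same way.
net = nn.Conv2d(3, 8, kernel_size=3, padding=1).to(device)
net.eval()

# Hypothetical preprocessed image: CHW float32, as a transform might produce.
img = np.random.rand(3, 300, 300).astype(np.float32)

# Add the batch dimension and move the input to the model's device,
# instead of calling .cuda() unconditionally.
x = torch.from_numpy(img).unsqueeze(0).to(device)

# Inference only, so skip building the autograd graph.
with torch.no_grad():
    y = net(x)

print(y.shape, y.device)
```

Because the model and the input are moved with the same device object, the script should run unchanged on CPU-only machines as well as on a GPU.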
Ah, thanks for pointing that out; it must only be the case on GPU. For some reason I was using torch.Tensor() instead of torch.from_numpy() in data_augment.py. It should all be fixed now. Feel free to send a pull request if you find any other errors!
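The actual data_augment.py change isn't shown in this thread, so the helper below is only a hypothetical sketch of a torch.from_numpy-based final transform step. It assumes the preprocessed image is an HWC float32 array (as OpenCV-style code typically produces) and that the network wants CHW input. torch.from_numpy shares memory with the NumPy array and keeps its dtype rather than reinterpreting the array as a nested Python sequence, so the array is cast to float32 up front.

```python
import numpy as np
import torch


def to_chw_tensor(img: np.ndarray) -> torch.Tensor:
    """Hypothetical helper: HWC float32 array -> CHW torch tensor.

    np.ascontiguousarray returns the input unchanged when it is already a
    C-contiguous float32 array, so torch.from_numpy then shares its memory.
    permute reorders the dimensions to CHW as a view, without copying.
    """
    img = np.ascontiguousarray(img, dtype=np.float32)
    return torch.from_numpy(img).permute(2, 0, 1)


# Hypothetical mean-subtracted 300x300 3-channel image.
image = np.random.rand(300, 300, 3).astype(np.float32)
t = to_chw_tensor(image)
print(t.shape, t.dtype)  # torch.Size([3, 300, 300]) torch.float32
```

Since t shares memory with image, in-place edits to one are visible in the other; call .clone() on the tensor if an independent copy is needed.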