In [1]:
from __future__ import print_function

In [2]:
import urllib
import bz2
import os
import numpy as np
import torch
from torch import autograd

%matplotlib inline
import matplotlib.pyplot as plt

In [3]:
device = torch.device('cpu')

In [4]:
BASE_URL = "http://grail.cs.washington.edu/projects/bal/data/ladybug/"
FILE_NAME = "problem-49-7776-pre.txt.bz2"
URL = BASE_URL + FILE_NAME

In [5]:
if not os.path.isfile(FILE_NAME):
    urllib.request.urlretrieve(URL, FILE_NAME)

In [35]:
def read_bal_data(file_name):
    with bz2.open(file_name, "rt") as file:
        n_cameras, n_points, n_observations = map(int, file.readline().split())
    
        camera_indices = np.empty(n_observations, dtype = int)
        point_indices = np.empty(n_observations, dtype = int)
        points_2d = torch.empty(n_observations, 2, device = device)

        for i in range(n_observations):
            camera_index, point_index, x, y = file.readline().split()
            camera_indices[i] = camera_index
            point_indices[i] = point_index
            points_2d[i] = torch.tensor([float(x), float(y)])

            camera_params = torch.empty(n_cameras*9, device = device)

        for i in range(n_cameras*9):
            camera_params[i] = float(file.readline())

        camera_params = camera_params.view(n_cameras, -1)

        points_3d = torch.empty(n_points*3, device = device)

        for i in range(n_points*3):
            points_3d[i] = float(file.readline())
        points_3d = points_3d.view(n_points, -1)
            
    return camera_params, points_3d, camera_indices, point_indices, points_2d

In [48]:
c_params, p3d, camera_indices, point_indices, points_2d = read_bal_data(FILE_NAME)

In [49]:
camera_indices = torch.tensor(camera_indices, device = device)
point_indices = torch.tensor(point_indices, device = device)

In [50]:
n_cameras = c_params.size()[0]
n_points = p3d.size()[0]

n = 9*n_cameras + 3*n_points
m = 2*points_2d.size()[0]

print("n_cameras: {}".format(n_cameras))
print("n_points: {}".format(n_points))
print("Total number of parameters: {}".format(n))
print("Total number of residuals: {}".format(m))

n_cameras: 49
n_points: 7776
Total number of parameters: 23769
Total number of residuals: 63686


In [51]:
c_params.requires_grad_(True)
p3d.requires_grad_(True)

tensor([[-0.6120,  0.5718, -1.8471],
        [ 1.7075,  0.9539, -6.8772],
        [-0.3734,  1.5359, -4.7824],
        ...,
        [-0.6642, -0.1351, -5.5425],
        [-0.8193,  0.0765, -4.5143],
        [-0.7480,  0.0371, -4.8132]], requires_grad=True)

In [24]:
def rotate(points, rot_vecs):
    
    theta = torch.norm(rot_vecs, dim = 1, keepdim=True)
    v = rot_vecs/theta
#         v[v != v] = 0.
#     print(v.size(), points.size())
    
    dot = torch.sum(points*v, dim = 1, keepdim = True)
    
    cos_theta = torch.cos(theta)
    sin_theta = torch.sin(theta)
    
    ans = cos_theta*points + sin_theta*torch.cross(v, points) + dot*(1-cos_theta)*v
    
    return ans

In [12]:
def project(points, camera_params):
    R = torch.index_select(camera_params, 1, torch.tensor([0,1,2]))
    T = torch.index_select(camera_params, 1, torch.tensor([3,4,5]))
    # print(R.size(), T.size())
    points_proj = rotate(points, R)
    points_proj = points_proj + T
    denom = torch.index_select(points_proj,1,torch.tensor([2])).view(-1,1)
    points_proj_2 = -torch.index_select(points_proj,1,torch.tensor([0,1]))/denom
#     f = camera_params[:,6]
#     k1 = camera_params[:,7]
#     k2 = camera_params[:,8]
    f = torch.index_select(camera_params, 1, torch.tensor([6]))
    k1 = torch.index_select(camera_params, 1, torch.tensor([7]))
    k2 = torch.index_select(camera_params, 1, torch.tensor([8]))
    
    n = torch.sum(torch.mul(points_proj_2,points_proj_2), dim = 1)
    # print(f.size(), k1.size(), k2.size(), n.size())
    r = 1 + torch.mul(n,k1.view(-1)) + torch.mul(k2.view(-1),torch.mul(n,n))
    # print(r.size(), f.size(), points_proj.size())
    # print(torch.mul(r,f.view(f.numel())).size())
    points_proj_3 = points_proj_2*torch.mul(r,f.view(-1)).unsqueeze(1)
    return points_proj_3

In [13]:
def fun(camera_params, points_3d, n_cameras, n_points, camera_indices, point_indices, points_2d):
#     cp = params[:n_cameras*9].view(n_cameras, 9)
#     p3d = params[n_cameras*9:].view(n_points, 3)
#     points_proj = project(p3d[point_indices], cp[camera_indices])
    points_3d_2 = torch.index_select(points_3d, 0, point_indices)
    camera_params_2 = torch.index_select(camera_params, 0, camera_indices)
    points_proj = project(points_3d_2, camera_params_2)
    ans = points_proj - points_2d
    return ans.view(-1)

In [46]:
f0 = fun(camera_params, points_3d, n_cameras, n_points, camera_indices, point_indices, points_2d)

In [47]:
f = fun(camera_params, points_3d, n_cameras, n_points, camera_indices, point_indices, points_2d)

In [36]:
loss = f.pow(2).sum()

In [54]:
points_3d = torch.index_select(p3d, 0, point_indices)
camera_params = torch.index_select(c_params, 0, camera_indices)

In [55]:
R = torch.index_select(camera_params, 1, torch.tensor([0,1,2]))
T = torch.index_select(camera_params, 1, torch.tensor([3,4,5]))
theta = torch.norm(R, dim = 1, keepdim=True)
v = R/theta
dot = torch.sum(points_3d*v, dim = 1, keepdim = True)

cos_theta = torch.cos(theta)
sin_theta = torch.sin(theta)

points_proj = cos_theta*points_3d + sin_theta*torch.cross(v, points_3d) + dot*(1-cos_theta)*v
points_proj = points_proj + T
denom = torch.index_select(points_proj,1,torch.tensor([2])).view(-1,1)
points_proj = -torch.index_select(points_proj,1,torch.tensor([0,1]))/denom

f = torch.index_select(camera_params, 1, torch.tensor([6]))
k1 = torch.index_select(camera_params, 1, torch.tensor([7]))
k2 = torch.index_select(camera_params, 1, torch.tensor([8]))

n = torch.sum(torch.mul(points_proj,points_proj), dim = 1)
r = 1 + torch.mul(n,k1.view(-1)) + torch.mul(k2.view(-1),torch.mul(n,n))
points_proj = points_proj*torch.mul(r,f.view(-1)).unsqueeze(1)


In [56]:
f = (points_proj - points_2d).view(-1)

In [57]:
loss = f.pow(2).sum()

In [58]:
loss.backward()

In [61]:
print(p3d.grad)

tensor([[-2.6374e+04,  1.7635e+04,  3.0804e+04],
        [ 1.0622e+03,  7.0126e+02,  5.4114e+02],
        [-1.4381e+04,  3.6268e+04,  3.6153e+04],
        ...,
        [ 2.0478e-01, -1.0027e+01,  3.9044e-01],
        [ 8.7769e-01,  6.0243e+00, -2.4398e-01],
        [-1.1459e+00, -1.1489e+01,  2.0699e-01]])


In [52]:
lr=1e-6

In [53]:
for i in range(1):
    f = fun(c_params, p3d, n_cameras, n_points, camera_indices, point_indices, points_2d)
    
    loss = f.pow(2).sum()
    
    print(i, " --> ", loss.item())
    loss.backward()
    
    
    
    with torch.no_grad():
        print(c_params)
        c_params -= lr*c_params.grad
        p3d -= lr*p3d.grad
        print(c_params)
        c_params.grad.zero_()
        p3d.grad.zero_()

0  -->  1701824.875
tensor([[ 1.5742e-02, -1.2791e-02, -4.4008e-03, -3.4094e-02, -1.0751e-01,
          1.1202e+00,  3.9975e+02, -3.1771e-07,  5.8820e-13],
        [ 1.5977e-02, -2.5224e-02, -9.4001e-03, -8.5668e-03, -1.2188e-01,
          7.1901e-01,  4.0202e+02, -3.7805e-07,  9.3074e-13],
        [ 1.4335e-02, -2.8132e-03, -6.4099e-03, -3.6518e-02, -9.8322e-02,
          1.3142e+00,  3.9945e+02, -3.1712e-07,  5.4981e-13],
        [ 1.4846e-02, -2.1063e-02, -1.1669e-03, -2.4951e-02, -1.1398e-01,
          9.2166e-01,  4.0040e+02, -3.2953e-07,  6.7329e-13],
        [ 1.4383e-02,  1.4437e-03, -6.3331e-03, -4.6798e-02, -9.0595e-02,
          1.5019e+00,  3.9934e+02, -3.2059e-07,  5.3774e-13],
        [ 1.2547e-02, -2.0898e-02, -6.3758e-03,  1.1686e-02, -1.2684e-01,
          5.1416e-01,  4.0254e+02, -3.8581e-07,  1.0499e-12],
        [ 1.3596e-02,  4.5686e-03, -6.1368e-03, -5.6648e-02, -8.3145e-02,
          1.6830e+00,  3.9895e+02, -3.0134e-07,  4.3330e-13],
        [ 1.5239e-02, -1.797

In [60]:
current_f = loss.grad_fn
print(current_f)
while True:
    current_f = current_f.next_functions[0][0]
    print(current_f)

<SumBackward0 object at 0x7f77604d8ad0>
<PowBackward0 object at 0x7f77537cac10>
<ViewBackward object at 0x7f7753dfba90>
<SubBackward0 object at 0x7f7753807c90>
<MulBackward0 object at 0x7f7753812150>
<DivBackward0 object at 0x7f77535418d0>
<NegBackward object at 0x7f7753fd4bd0>
<IndexSelectBackward object at 0x7f7753541d50>
<AddBackward0 object at 0x7f7753812590>
<AddBackward0 object at 0x7f77537cac10>
<AddBackward0 object at 0x7f7753dfb2d0>
<MulBackward0 object at 0x7f7753dfbe50>
<CosBackward object at 0x7f7753fd4bd0>
<NormBackward3 object at 0x7f7753fd4850>
<IndexSelectBackward object at 0x7f7753fd4ad0>
<IndexSelectBackward object at 0x7f7753fcb690>
<AccumulateGrad object at 0x7f7753fcb090>


IndexError: tuple index out of range