In [185]:
from __future__ import print_function

In [186]:
import bz2
import os
import urllib
import urllib.request

import numpy as np
import torch
from torch import autograd
from torch import optim

%matplotlib inline
import matplotlib.pyplot as plt

In [187]:
# Device used when creating the tensors in read_bal_data() and when converting
# the observation index arrays; everything in this notebook runs on the CPU.
device = torch.device('cpu')

In [188]:
# Remote location and local file name of the bz2-compressed problem file that
# the following cells download (if missing) and parse with read_bal_data().
BASE_URL = "http://grail.cs.washington.edu/projects/bal/data/ladybug/"
FILE_NAME = "problem-49-7776-pre.txt.bz2"
URL = BASE_URL + FILE_NAME

In [189]:
# Fetch the problem file once. The download is written to a temporary ".part"
# name and only renamed to FILE_NAME after it completes, so an interrupted
# transfer is never mistaken for a cached copy by the isfile() check on the
# next run.
if not os.path.isfile(FILE_NAME):
    partial_name = FILE_NAME + ".part"
    try:
        urllib.request.urlretrieve(URL, partial_name)
        os.replace(partial_name, FILE_NAME)
    finally:
        # Only still present when the download or the rename failed.
        if os.path.exists(partial_name):
            os.remove(partial_name)

In [190]:
def read_bal_data(file_name, target_device=None):
    """Parse a bz2-compressed BAL problem file into tensors and index arrays.

    The text file is read in this order: a header line holding the number of
    cameras, points and observations; one line per observation holding
    ``camera_index point_index x y``; 9 values per camera, one per line;
    3 values per point, one per line.

    Args:
        file_name: Path of the ``.bz2`` file to read.
        target_device: Optional torch device for the returned tensors. When
            omitted, the notebook-level ``device`` is used.

    Returns:
        Tuple ``(camera_params, points_3d, camera_indices, point_indices,
        points_2d)``. ``camera_params`` has shape (n_cameras, 9),
        ``points_3d`` has shape (n_points, 3) and ``points_2d`` has shape
        (n_observations, 2); all three are tensors of torch's default float
        dtype. The two index arrays are NumPy integer arrays of length
        n_observations.
    """
    out_device = device if target_device is None else target_device
    float_dtype = torch.get_default_dtype()

    with bz2.open(file_name, "rt") as file:
        n_cameras, n_points, n_observations = map(int, file.readline().split())

        camera_indices = np.empty(n_observations, dtype=int)
        point_indices = np.empty(n_observations, dtype=int)
        observations = []
        for i in range(n_observations):
            camera_index, point_index, x, y = file.readline().split()
            camera_indices[i] = int(camera_index)
            point_indices[i] = int(point_index)
            observations.append((float(x), float(y)))

        # Collect the raw values in Python lists and build each tensor once,
        # instead of assigning into a tensor one element at a time.
        camera_values = [float(file.readline()) for _ in range(n_cameras * 9)]
        point_values = [float(file.readline()) for _ in range(n_points * 3)]

    # Explicit target shapes keep the result well defined for empty sections.
    points_2d = torch.tensor(
        observations, dtype=float_dtype, device=out_device
    ).view(n_observations, 2)
    camera_params = torch.tensor(
        camera_values, dtype=float_dtype, device=out_device
    ).view(n_cameras, 9)
    points_3d = torch.tensor(
        point_values, dtype=float_dtype, device=out_device
    ).view(n_points, 3)

    return camera_params, points_3d, camera_indices, point_indices, points_2d

In [172]:
# Parse the downloaded file into the per-camera parameters, the 3D points, the
# camera/point index of every observation and the observed 2D coordinates.
c_params, p3d, camera_indices, point_indices, points_2d = read_bal_data(FILE_NAME)

In [173]:
# Convert the NumPy index arrays returned by read_bal_data() into tensors on
# `device`; fun() passes them to torch.index_select.
# NOTE: the names are rebound, so re-running this cell hands tensors (not
# arrays) to torch.tensor.
camera_indices = torch.tensor(camera_indices, device = device)
point_indices = torch.tensor(point_indices, device = device)

In [174]:
# Problem dimensions: one row per camera / per point in the parameter tensors.
n_cameras, n_points = c_params.shape[0], p3d.shape[0]

# 9 unknowns per camera and 3 per point; every 2D observation gives 2 residuals.
n = n_cameras*9 + n_points*3
m = points_2d.shape[0]*2

print("n_cameras: {}".format(n_cameras))
print("n_points: {}".format(n_points))
print("Total number of parameters: {}".format(n))
print("Total number of residuals: {}".format(m))

n_cameras: 49
n_points: 7776
Total number of parameters: 23769
Total number of residuals: 63686


In [175]:
# Track gradients for both parameter tensors so loss.backward() in train()
# populates their .grad and the optimizer can update them.
# The second call is the cell's last expression, so the tensor it returns is
# displayed as the cell output.
c_params.requires_grad_(True)
p3d.requires_grad_(True)

tensor([[-0.6120,  0.5718, -1.8471],
        [ 1.7075,  0.9539, -6.8772],
        [-0.3734,  1.5359, -4.7824],
        ...,
        [-0.6642, -0.1351, -5.5425],
        [-0.8193,  0.0765, -4.5143],
        [-0.7480,  0.0371, -4.8132]], requires_grad=True)

In [176]:
def rotate(points, rot_vecs):
    """Rotate each point by its axis-angle vector using Rodrigues' formula.

    Args:
        points: Tensor of shape (N, 3), one 3D point per row.
        rot_vecs: Tensor of shape (N, 3). Each row's norm is the rotation
            angle in radians and its direction is the rotation axis. A zero
            row leaves the corresponding point unchanged.

    Returns:
        Tensor of shape (N, 3) holding the rotated points.
    """
    theta = torch.norm(rot_vecs, dim=1, keepdim=True)

    # Divide by 1 instead of 0 for zero rotations. This yields a zero axis
    # without the 0/0 division whose backward pass produces NaN gradients.
    safe_theta = torch.where(theta > 0, theta, torch.ones_like(theta))
    v = rot_vecs / safe_theta

    dot = torch.sum(points * v, dim=1, keepdim=True)

    cos_theta = torch.cos(theta)
    sin_theta = torch.sin(theta)

    # dim=1 is explicit: without it the cross product is taken along the first
    # dimension of size 3, which is the batch dimension when N == 3.
    axis_cross_points = torch.cross(v, points, dim=1)

    return cos_theta * points + sin_theta * axis_cross_points + dot * (1 - cos_theta) * v

In [177]:
def project(points, camera_params):
    """Project 3D points to 2D with a per-row camera model.

    Each row of ``camera_params`` is read as: columns 0-2 a rotation vector
    (passed to ``rotate``), columns 3-5 a translation, column 6 the focal
    length ``f`` and columns 7-8 the radial distortion coefficients ``k1`` and
    ``k2``.

    Columns are taken by slicing rather than by ``index_select`` with freshly
    built index tensors, so no CPU-resident index tensor is combined with
    ``camera_params`` when it lives on another device.

    Args:
        points: Tensor of shape (N, 3), one 3D point per row.
        camera_params: Tensor of shape (N, 9), the camera paired with each point.

    Returns:
        Tensor of shape (N, 2) with the projected 2D coordinates.
    """
    rot_vecs = camera_params[:, 0:3]
    translation = camera_params[:, 3:6]

    # Move the points into the camera frame.
    points_cam = rotate(points, rot_vecs) + translation

    # Perspective division by the third coordinate, with a sign flip.
    points_img = -points_cam[:, 0:2] / points_cam[:, 2:3]

    f = camera_params[:, 6]
    k1 = camera_params[:, 7]
    k2 = camera_params[:, 8]

    # Radial distortion factor r = 1 + k1*|p|^2 + k2*|p|^4.
    n = torch.sum(points_img * points_img, dim=1)
    r = 1 + n * k1 + k2 * (n * n)

    return points_img * (r * f).unsqueeze(1)

In [178]:
def fun(camera_params, points_3d, n_cameras, n_points, camera_indices, point_indices, points_2d):
    """Return the flattened reprojection residuals of all observations.

    For observation ``i`` the point ``points_3d[point_indices[i]]`` is
    projected with ``camera_params[camera_indices[i]]`` and the observed
    ``points_2d[i]`` is subtracted. ``n_cameras`` and ``n_points`` are accepted
    but not used.

    Returns:
        1-D tensor with two residuals per observation.
    """
    observed_points = points_3d.index_select(0, point_indices)
    observing_cameras = camera_params.index_select(0, camera_indices)
    residuals = project(observed_points, observing_cameras) - points_2d
    return residuals.view(-1)

def loss_fn(f):
    """Return the sum of the squared entries of the residual tensor ``f``."""
    squared = f ** 2
    return torch.sum(squared)

In [179]:
def train(optimizer, epochs, loss_fn, fun):
    """Run ``epochs`` optimisation steps on the notebook-level problem.

    Every step evaluates ``fun`` on the global tensors (``c_params``, ``p3d``,
    the index tensors and ``points_2d``), reduces the result with ``loss_fn``,
    prints the step number and loss value, back-propagates and lets
    ``optimizer`` update the parameters.
    """
    for epoch in range(epochs):
        # Clear gradients accumulated by the previous step.
        optimizer.zero_grad()

        residuals = fun(c_params, p3d, n_cameras, n_points, camera_indices, point_indices, points_2d)
        loss = loss_fn(residuals)
        print(epoch, " --> ", loss.item())

        loss.backward()
        optimizer.step()

In [180]:
# RMSprop jointly updates the camera parameters and the 3D points, using a
# learning rate of 1e-3.
optimizer = optim.RMSprop([c_params, p3d], lr=1e-3)

In [184]:
# Run 500 optimisation steps; train() prints the loss at every step.
# NOTE(review): the recorded output below oscillates around 8e5 instead of
# decreasing steadily -- the optimizer settings may need tuning.
train(optimizer, 500, loss_fn, fun)

0  -->  815059.8125
1  -->  815103.1875
2  -->  819497.1875
3  -->  819030.1875
4  -->  823463.3125
5  -->  821461.375
6  -->  824460.0625
7  -->  821475.1875
8  -->  822136.8125
9  -->  817855.0
10  -->  816198.8125
11  -->  811478.1875
12  -->  809390.9375
13  -->  805750.875
14  -->  804590.5
15  -->  802394.3125
16  -->  802200.0
17  -->  800955.0
18  -->  801192.6875
19  -->  800352.0
20  -->  800708.4375
21  -->  800275.0
22  -->  801035.5625
23  -->  801530.8125
24  -->  803155.3125
25  -->  804517.4375
26  -->  806117.875
27  -->  806123.4375
28  -->  805536.375
29  -->  803605.5625
30  -->  802076.375
31  -->  800600.4375
32  -->  800103.8125
33  -->  799887.4375
34  -->  800604.625
35  -->  801478.125
36  -->  803074.25
37  -->  804864.75
38  -->  807066.3125
39  -->  809692.125
40  -->  812205.1875
41  -->  815502.125
42  -->  817719.4375
43  -->  820884.5625
44  -->  821350.3125
45  -->  822567.4375
46  -->  820666.9375
47  -->  819622.625
48  -->  817180.0
49  -->  815363.

394  -->  768407.5625
395  -->  770336.75
396  -->  770910.1875
397  -->  773744.5625
398  -->  774766.0
399  -->  778302.25
400  -->  780062.1875
401  -->  784199.375
402  -->  786994.375
403  -->  791562.8125
404  -->  795325.125
405  -->  799786.5
406  -->  803960.8125
407  -->  807246.3125
408  -->  810719.9375
409  -->  811570.0625
410  -->  813120.375
411  -->  811392.0
412  -->  810880.5
413  -->  807974.6875
414  -->  806128.625
415  -->  803083.3125
416  -->  800344.9375
417  -->  797329.4375
418  -->  794344.0
419  -->  791701.5
420  -->  789368.4375
421  -->  787413.5
422  -->  786142.0
423  -->  784954.6875
424  -->  784744.0
425  -->  784425.875
426  -->  785185.75
427  -->  785805.0
428  -->  787238.875
429  -->  788247.25
430  -->  789855.125
431  -->  790527.4375
432  -->  792409.3125
433  -->  792477.0
434  -->  794732.0625
435  -->  794041.8125
436  -->  796174.0
437  -->  794822.3125
438  -->  796242.875
439  -->  794009.9375
440  -->  794138.5625
441  -->  791086.75