In [1]:
from __future____ import print_function

In [2]:
import logging
import os

import numpy as np
import torch
from torch import nn

In [3]:
class Sparse3dBA(nn.Module):
    """A simple two-view pose+depth estimator based on iterative optimization.

    Optimizes over the rotation R, the translation t, and optionally the
    per-point depth z0 in camera 0, by repeatedly solving a damped
    least-squares step on the reprojection error in camera 1.
    Convention: the pose is from camera 0 to camera 1.
    Warning: not yet batched, and not fully tested.
    """

    def __init__(self, iterations, loss_fn=squared_loss, lambda_=0.01,
                 opt_depth=True, verbose=False):
        """Store the optimizer configuration.

        Args:
            iterations: maximum number of optimization steps.
            loss_fn: robust loss callable forwarded to `scaled_loss`.
            lambda_: initial damping factor passed to `optimizer_step`.
            opt_depth: if True, the depths z0 are optimized jointly with the
                pose; otherwise `forward` requires `z0_gt` and keeps it fixed.
            verbose: if True, print the cost (and the pose error when ground
                truth is given) during the optimization.
        """
        super().__init__()
        self.iterations = iterations
        self.loss_fn = loss_fn
        self.lambda_ = lambda_
        self.opt_depth = opt_depth
        self.verbose = verbose

    def forward(self, pts0, pts1, confidence=None, scale=None,
                z0_gt=None, R_gt=None, t_gt=None):
        """Estimate the relative pose (and depths) from 2D correspondences.

        Args:
            pts0: 2D points in camera 0; presumably N x 2 normalized image
                coordinates, since they are lifted with `to_homogeneous` and
                multiplied by a 3x3 rotation - TODO confirm with callers.
            pts1: corresponding 2D points in camera 1, same shape as `pts0`.
            confidence: optional per-point weights multiplied into the cost
                and the solver weights; ignored when None.
            scale: optional scale of the robust loss, forwarded to
                `scaled_loss`. Defaults to ones (assumed to mean "no
                rescaling" - TODO confirm against `scaled_loss`).
            z0_gt: depths in camera 0; required when `opt_depth` is False,
                otherwise only used for the verbose pose error.
            R_gt, t_gt: optional ground-truth pose, only used for the verbose
                pose error.

        Returns:
            A tuple `(R, t, z0)` holding the last accepted estimates.

        Raises:
            ValueError: if `opt_depth` is False and `z0_gt` is None.
        """
        # Initial guess: identity rotation, a fixed non-zero translation and
        # unit depths, all with the dtype/device of the input points.
        R = torch.eye(3).to(pts0)
        t = pts0.new_tensor([1., 1., 0.])
        z0 = torch.ones_like(pts0[..., 0])

        if scale is None:
            # One scale per problem, later broadcast over the N points.
            scale = pts0.new_ones(pts0.shape[:-2])

        pts0_h = to_homogeneous(pts0)
        lambda_ = self.lambda_
        if not self.opt_depth:
            if z0_gt is None:
                raise ValueError('z0_gt is required when opt_depth is False')
            z0 = z0_gt

        for i in range(self.iterations):
            # Back-project with the current depths, move to camera 1, project.
            p_3d_1 = (pts0_h * z0[..., None]) @ R.transpose(-1, -2) + t
            p_proj_1 = from_homogeneous(p_3d_1)
            error = p_proj_1 - pts1
            cost = (error**2).sum(-1)

            cost, weights, _ = scaled_loss(
                cost, self.loss_fn, scale[..., None])
            if confidence is not None:
                weights = weights * confidence
                cost = cost * confidence
            if i == 0:
                prev_cost = cost.mean(-1)
            if self.verbose:
                print('Iter ', i, cost.mean().item())

            # Jacobian of the 3D point in camera 1 w.r.t. the 6 pose
            # parameters, ordered [translation, rotation] to match how
            # `delta` is split below.
            J_p_T = torch.cat([
                batched_eye_like(p_3d_1, 3), -skew_symmetric(p_3d_1)], -1)
            # Jacobian of the 3D point w.r.t. its own depth in camera 0.
            J_p_d = R @ pts0_h[..., None]

            # Jacobian of the projection (x/z, y/z) w.r.t. the 3D point.
            shape = p_3d_1.shape[:-1]
            o, z = p_3d_1.new_ones(shape), p_3d_1.new_zeros(shape)
            J_e_p = torch.stack([
                o, z, -p_3d_1[..., 0] / p_3d_1[..., 2],
                z, o, -p_3d_1[..., 1] / p_3d_1[..., 2],
            ], dim=-1).reshape(shape+(2, 3)) / p_3d_1[..., 2, None, None]

            J_e_T = J_e_p @ J_p_T
            J_e_d = (J_e_p @ J_p_d).squeeze(-1)

            Grad = torch.einsum('...ijk,...ij->...ik', J_e_T, error)
            Grad = weights[..., None] * Grad
            Grad = Grad.sum(-2)  # Grad was ... x N x 6
            if self.opt_depth:
                # Each depth only affects its own residual: one entry per
                # point, appended after the 6 pose entries.
                Grad_depth = torch.sum(J_e_d * error, -1)
                Grad = torch.cat([Grad, weights * Grad_depth], -1)

            if self.opt_depth:
                # TODO: the hessian could be built more efficiently
                # The depth columns are block-diagonal: point n only has a
                # non-zero entry in column 6 + n.
                J = torch.cat([
                    J_e_T,
                    J_e_d.transpose(-1, -2).diag_embed(dim1=-3, dim2=-1)], -1)
            else:
                J = J_e_T
            Hess = torch.einsum('...ijk,...ijl->...ikl', J, J)
            Hess = weights[..., None, None] * Hess
            Hess = Hess.sum(-3)  # Hess was ... x N x 6 x 6

            delta = optimizer_step(Grad, Hess, lambda_)
            if torch.isnan(delta).any():
                logging.warning('NaN detected, exit')
                break
            # Slice the last axis everywhere so that the depth update stays
            # correct if leading dimensions are ever added.
            dt, dw, dd = delta[..., :3], delta[..., 3:6], delta[..., 6:]
            dr = so3exp_map(dw)
            R_new = dr @ R
            t_new = dr @ t + dt
            if self.opt_depth:
                # Keep the depths strictly positive.
                z0_new = torch.max(z0 + dd, z0.new_tensor(1e-5))
            else:
                z0_new = z0

            # Evaluate the candidate update with the same weighting as `cost`.
            new_error = from_homogeneous(
                (pts0_h * z0_new[..., None]) @ R_new.transpose(-1, -2)
                + t_new) - pts1
            new_cost = (new_error**2).sum(-1)
            new_cost = scaled_loss(new_cost, self.loss_fn, scale[..., None])[0]
            if confidence is not None:
                new_cost = confidence * new_cost
            new_cost = new_cost.mean(-1)

            # Adapt the damping: stronger after a failed step, weaker after a
            # successful one, clamped to [1e-6, 1e2].
            cost_increased = bool(new_cost > prev_cost)
            lambda_ = min(
                max(lambda_ * (10 if cost_increased else 1 / 10), 1e-6), 1e2)
            if cost_increased:
                # Reject the step and retry from the same estimate.
                continue
            prev_cost = new_cost
            R, t, z0 = R_new, t_new, z0_new

            if (self.verbose and z0_gt is not None
                    and R_gt is not None and t_gt is not None):
                # The translation is rescaled by the mean depth ratio before
                # being compared with the ground truth.
                gt_scale = (z0_gt / z0).abs().mean()
                gt_error = pose_error(R, t * gt_scale, R_gt, t_gt)
                print('Pose error:', *gt_error)

        return R, t, z0


NameError: name 'squared_loss' is not defined