# Definitions

In [1]:
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd import Variable
from torch import FloatTensor
import torch.optim as optim
from torch import Tensor
import torch
from torch import nn
from torch.nn import Parameter

import numpy as np

from pyquaternion import Quaternion

# numpy implementation of yi zhou's method
def norm(v):
    return v/np.linalg.norm(v)

def gs(M):
    a1 = M[:,0]
    a2 = M[:,1]
    b1 = norm(a1)
    b2 = norm((a2-np.dot(b1,a2)*b1))
    b3 = np.cross(b1,b2)
    return np.vstack([b1,b2,b3]).T

# input sz bszx3x2
def bgs(d6s):
    bsz = d6s.shape[0]
    b1 = F.normalize(d6s[:,:,0], p=2, dim=1)
    a2 = d6s[:,:,1]  
    b2 = F.normalize(a2-torch.bmm(b1.view(bsz,1,-1),a2.view(bsz,-1,1)).view(bsz,1)*b1,p=2,dim=1)    
    b3=torch.cross(b1,b2,dim=1)
    return torch.stack([b1,b2,b3],dim=1).permute(0,2,1)

class geodesic_loss_R(nn.Module):
    def __init__(self,reduction='mean'):
        super().__init__()
        self.reduction = reduction
        self.eps = 1e-6

    # batch geodesic loss for rotation matrices
    def bgdR(self,Rgts,Rps):
        Rds = torch.bmm(Rgts.permute(0,2,1),Rps)
        Rt = torch.sum(Rds[:,torch.eye(3).byte()],1) #batch trace
        # necessary or it might lead to nans and the likes
        theta = torch.clamp(0.5*(Rt-1), -1+self.eps, 1-self.eps)
        return torch.acos(theta)

    def forward(self, ypred, ytrue):
        theta = self.bgdR(ypred,ytrue)
        if self.reduction == 'mean':
            return torch.mean(theta)
        else:
            return theta



# Mini training loop setup

In [4]:
np.random.seed(3434)

R1 = Quaternion.random().rotation_matrix
R2 = Quaternion.random().rotation_matrix

Rgts = np.stack([R1,R2]).astype(np.float32)
d6s = np.random.uniform(-0.5,0.5,size=(2,3,2)).astype(np.float32)

Rgts = torch.from_numpy(Rgts) # gt rotations
d6s = torch.from_numpy(d6s) # random 6d representation

In [5]:
geodesic_loss = geodesic_loss_R()
P = Parameter(FloatTensor(d6s), requires_grad = True)
optimizer = optim.Adam([P],0.01)
Rgts = Variable(Rgts)

steps = 0
while True:
    Rgs = bgs(P)
    loss = geodesic_loss(Rgs,Rgts)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    steps+=1
    print('%d %s'%(steps,loss.item()))
    if loss.item()<0.01:
        print('done')
        break



1 2.804333209991455
2 2.7662675380706787
3 2.7275469303131104
4 2.6883912086486816
5 2.649040460586548
6 2.609743595123291
7 2.570744514465332
8 2.5322651863098145
9 2.4944915771484375
10 2.457561492919922
11 2.421558380126953
12 2.3865036964416504
13 2.352360963821411
14 2.3190412521362305
15 2.2864246368408203
16 2.2543752193450928
17 2.22275447845459
18 2.1914305686950684
19 2.1602909564971924
20 2.129246234893799
21 2.09822940826416
22 2.0671892166137695
23 2.0360822677612305
24 2.0048670768737793
25 1.9735008478164673
26 1.9419374465942383
27 1.910127878189087
28 1.8780207633972168
29 1.845564842224121
30 1.8127119541168213
31 1.779419183731079
32 1.7456527948379517
33 1.711390495300293
34 1.6766245365142822
35 1.6413638591766357
36 1.605635404586792
37 1.5694844722747803
38 1.5329723358154297
39 1.4961738586425781
40 1.459172248840332
41 1.422052025794983
42 1.3848921060562134
43 1.347761631011963
44 1.3107181787490845
45 1.2738103866577148
46 1.237086296081543
47 1.2006024122238



In [6]:
bgs(d6s) # optimized rotations

tensor([[[ 0.4163, -0.8772,  0.2393],
         [ 0.3215,  0.3882,  0.8637],
         [-0.8505, -0.2826,  0.4436]],

        [[ 0.6883, -0.7110,  0.1439],
         [-0.5411, -0.3711,  0.7547],
         [-0.4831, -0.5973, -0.6401]]])

In [7]:
Rgts # gt rotations

tensor([[[ 0.4173, -0.8735,  0.2509],
         [ 0.3225,  0.4004,  0.8577],
         [-0.8496, -0.2770,  0.4488]],

        [[ 0.6903, -0.7131,  0.1225],
         [-0.5312, -0.3845,  0.7550],
         [-0.4913, -0.5862, -0.6442]]])