In [None]:
%load_ext autoreload
%autoreload 2

import torch
from torch import nn
from torch.nn import functional as F

In [None]:
# Two example quaternions (component order, e.g. w-first, is an assumption — TODO confirm).
# F.normalize rescales each row (dim=1) to unit L2 norm, giving valid rotation quaternions.
quat = torch.tensor(
    [[0.0, 1.0, 0.0, 0.0],
     [1.0, 1.0, 1.0, 1.0]],
    dtype=torch.float32,
)
unit_quats = F.normalize(quat, p=2.0, dim=1)
print(unit_quats)

In [None]:
class PosePrediction(nn.Module):
    """Siamese network that regresses the relative pose between two images.

    Both images pass through the same (shared-weight) backbone. The two feature
    maps are concatenated along the channel axis, reduced by a small conv head
    to one vector per sample, and fed to two linear heads. One head predicts a
    rotation (4 values, intended as a quaternion) and the other a translation
    (3 values).

    Args:
        backbone: Optional feature extractor applied to each image. It must
            output ``backbone_channels`` channels and a spatial map of at least
            5x5. When ``None`` (the default), an ImageNet-pretrained ResNet-18
            is loaded via ``torch.hub`` (needs network access or a populated
            hub cache), and its final avgpool and fc layers are removed.
        backbone_channels: Channels produced by ``backbone`` per image
            (512 for ResNet-18).
    """

    def __init__(self, backbone=None, backbone_channels: int = 512):
        super().__init__()

        if backbone is None:
            resnet = torch.hub.load('pytorch/vision:v0.5.0', 'resnet18', pretrained=True)
            # Drop the final avgpool + fc so the backbone yields a spatial feature map.
            backbone = nn.Sequential(*list(resnet.children())[:-2])
        self.backbone = backbone

        # The features of both images are concatenated channel-wise in forward().
        in_channels = 2 * backbone_channels
        hidden = 1024

        # Each unpadded 3x3 conv shrinks H and W by 2 (e.g. a 7x7 map -> 5x5 -> 3x3).
        # Global average pooling then reduces the result to 1x1.
        self.predict = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 3),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, hidden, 3),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
        )

        # Each Linear needs a ReLU after it. Without one, consecutive Linear
        # layers (including the fc_quat / fc_trns heads below) compose into a
        # single affine map, and the intermediate layers add no capacity.
        self.fc = nn.Sequential(
            nn.Linear(hidden, 120),
            nn.ReLU(inplace=True),
            nn.Linear(120, 84),
            nn.ReLU(inplace=True),
        )

        self.fc_quat = nn.Linear(84, 4)
        self.fc_trns = nn.Linear(84, 3)

    def forward(self, img1, img2):
        """Predict the relative pose between ``img1`` and ``img2``.

        Args:
            img1, img2: Image batches accepted by ``self.backbone``. Presumably
                (N, 3, H, W) with matching N, H and W; both feature maps must
                share a spatial size so they can be concatenated.

        Returns:
            ``(pred_quat, pred_trns)`` with shapes (N, 4) and (N, 3).
            ``pred_quat`` is the raw linear output and is NOT normalized.
            Apply ``F.normalize(pred_quat, dim=1)`` when a unit quaternion is
            required.
        """
        feat_img1 = self.backbone(img1)
        feat_img2 = self.backbone(img2)

        x = torch.cat((feat_img1, feat_img2), dim=1)
        x = self.predict(x)
        x = torch.flatten(x, 1)  # (N, hidden, 1, 1) -> (N, hidden)
        x = self.fc(x)

        pred_quat = self.fc_quat(x)
        pred_trns = self.fc_trns(x)
        return pred_quat, pred_trns

In [None]:
# Smoke test: push a random batch through the model and check the output shapes.
torch.manual_seed(0)  # makes the head initialisation and the sample inputs reproducible

model = PosePrediction()

BATCH_SIZE = 2
IMG_SHAPE = (3, 224, 224)  # standard ResNet-18 / ImageNet input size

sample_input1 = torch.rand((BATCH_SIZE, *IMG_SHAPE))
sample_input2 = torch.rand((BATCH_SIZE, *IMG_SHAPE))

pred_quat, pred_trns = model(sample_input1, sample_input2)

assert pred_quat.shape == (BATCH_SIZE, 4), pred_quat.shape
assert pred_trns.shape == (BATCH_SIZE, 3), pred_trns.shape

In [None]:
# Display the module tree (rich repr) as a quick sanity check of the architecture.
model