# Imports

In [15]:
from backbone.repvgg import get_RepVGG_func_by_name
from backbone.repvgg import repvgg_model_convert
from backbone.efficientnet_lite import build_efficientnet_lite
import numpy as np
import torch
import torch.nn as nn
import torch.onnx as onnx
# Artifact locations shared by the conversion cells further down.
modelloc, pthlloc, onnxlloc = (
    "./model.tar",   # training checkpoint read by convert()
    "./model.pth",   # weights written by convert() via repvgg_model_convert
    "./model.onnx",  # ONNX export target
)


# Utils

In [5]:
def normalize_vector(v, eps=1e-8):
    """Scale each row of ``v`` to unit L2 norm along dim 1.

    Norms smaller than ``eps`` are clamped up to ``eps`` before dividing, so an
    all-zero row stays zero instead of becoming NaN.

    Args:
        v: floating-point tensor whose dim 1 holds the vector components
            (callers in this notebook pass batch*3 tensors).
        eps: lower bound applied to each row's norm. Defaults to 1e-8, the value
            the original hard-coded. NOTE: 1e-8 underflows to 0 in float16, so
            pass a larger eps for half-precision inputs.

    Returns:
        Tensor with the same shape as ``v``. For floating-point input it also
        has the same device and dtype.
    """
    # keepdim=True leaves the norm as (batch, 1), so it broadcasts over the components.
    # clamp_min replaces the old per-device eps tensor. It works on any device, keeps
    # v's dtype, avoids the deprecated autograd.Variable API, and bakes no device-specific
    # branch into a traced/exported graph.
    v_mag = torch.sqrt(v.pow(2).sum(1, keepdim=True)).clamp_min(eps)
    return v / v_mag

def cross_product(u, v):
    """Row-wise 3-D cross product ``u x v``.

    Args:
        u, v: tensors of shape (batch, n) with n >= 3; only columns 0..2 are read.

    Returns:
        Tensor of shape (batch, 3).
    """
    ux, uy, uz = u[:, 0], u[:, 1], u[:, 2]
    vx, vy, vz = v[:, 0], v[:, 1], v[:, 2]
    # The formula is written out element-wise rather than with torch.cross.
    # NOTE(review): presumably this keeps the op set simple for ONNX export; confirm.
    return torch.stack(
        (
            uy * vz - uz * vy,
            uz * vx - ux * vz,
            ux * vy - uy * vx,
        ),
        dim=1,
    )



In [6]:
def compute_rotation_matrix_from_ortho6d(poses):
    """Convert a batch of 6-D rotation encodings into 3x3 rotation matrices.

    Columns 0..2 of each row give the first axis. Columns 3..5 give a second,
    not necessarily orthogonal, direction. The two are orthonormalized with
    cross products: x = normalize(a), z = normalize(x cross b), y = z cross x.

    Args:
        poses: tensor of shape (batch, >=6); only columns 0..5 are read.

    Returns:
        Tensor of shape (batch, 3, 3) whose columns are the x, y and z axes.
    """
    first_axis = normalize_vector(poses[:, 0:3])
    third_axis = normalize_vector(cross_product(first_axis, poses[:, 3:6]))
    second_axis = cross_product(third_axis, first_axis)
    # Place the axes as columns: [..., 0] is x, [..., 1] is y, [..., 2] is z.
    return torch.stack((first_axis, second_axis, third_axis), dim=2)

# Model Architecture

In [None]:
class SixDRepNet(nn.Module):
    """Head-pose regressor: RepVGG backbone -> 6-D vector -> 3x3 rotation matrix.

    The backbone's stage0..stage4 are reused as layer0..layer4. Global average
    pooling and a linear layer then produce 6 values, which
    compute_rotation_matrix_from_ortho6d maps to a rotation matrix.

    Args:
        backbone_name: RepVGG variant resolved by get_RepVGG_func_by_name
            (this notebook uses 'RepVGG-AZ').
        backbone_file: backbone checkpoint path; only read when pretrained=True.
        deploy: forwarded to the RepVGG constructor. NOTE(review): presumably True
            builds the single-branch (reparameterized) blocks; this notebook passes
            False before conversion and True when loading converted weights.
        pretrained: if True, load backbone_file into the backbone and freeze all
            backbone parameters (requires_grad=False).

    Raises:
        ValueError: if no Conv2d named '*rbr_dense*' or '*rbr_reparam*' is found
            in stage4, so the regression head's input width cannot be inferred.
    """

    def __init__(self,
                 backbone_name, backbone_file, deploy,
                 pretrained=True):
        super(SixDRepNet, self).__init__()
        repvgg_fn = get_RepVGG_func_by_name(backbone_name)
        backbone = repvgg_fn(deploy)
        if pretrained:
            # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
            # load_state_dict copies the tensors into the module's own parameters anyway.
            checkpoint = torch.load(backbone_file, map_location='cpu')
            if 'state_dict' in checkpoint:
                checkpoint = checkpoint['state_dict']
            # Strip the 'module.' key prefix (presumably left by a DataParallel/DDP wrapper).
            ckpt = {k.replace('module.', ''): v for k, v in checkpoint.items()}
            backbone.load_state_dict(ckpt)
            for param in backbone.parameters():
                param.requires_grad = False
        self.layer0 = backbone.stage0
        self.layer1 = backbone.stage1
        self.layer2 = backbone.stage2
        self.layer3 = backbone.stage3
        self.layer4 = backbone.stage4
        self.gap = nn.AdaptiveAvgPool2d(output_size=1)

        # The head's input width is the out_channels of the last matching conv in
        # stage4. This handles both the multi-branch ('rbr_dense') and the
        # reparameterized ('rbr_reparam') naming.
        last_channel = 0
        for n, m in self.layer4.named_modules():
            if ('rbr_dense' in n or 'rbr_reparam' in n) and isinstance(m, nn.Conv2d):
                last_channel = m.out_channels
        if last_channel == 0:
            # Fail loudly here rather than building nn.Linear(0, 6), which would
            # only fail later with a confusing shape mismatch in forward().
            raise ValueError(
                "SixDRepNet: no 'rbr_dense'/'rbr_reparam' Conv2d found in stage4 "
                "of backbone %r; cannot infer feature dimension." % (backbone_name,))

        fea_dim = last_channel

        self.linear_reg = nn.Linear(fea_dim, 6)

    def forward(self, x):
        """Run the backbone and head, returning (batch, 3, 3) rotation matrices."""
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.gap(x)
        x = torch.flatten(x, 1)
        x = self.linear_reg(x)
        return compute_rotation_matrix_from_ortho6d(x)



## EfficientNet backbone

In [9]:
class SixDENet(nn.Module):
    """Head-pose regressor: EfficientNet-Lite backbone -> 6-D vector -> 3x3 rotation matrix.

    The backbone is built with 1000 outputs. Its ``fc`` layer is then replaced by
    a 6-output linear layer, whose result compute_rotation_matrix_from_ortho6d
    turns into a rotation matrix.

    Args:
        backbone_name: variant passed to build_efficientnet_lite
            (e.g. 'efficientnet_lite3').
        backbone_file: weights file handed to ``backbone.load_pretrain``.
        deploy: accepted for signature parity with SixDRepNet; not used here.
        pretrained: when True and ``backbone_file`` is truthy, load the weights
            and put the backbone in eval mode.
    """

    def __init__(self,
                 backbone_name, backbone_file, deploy,
                 pretrained=True):
        super().__init__()
        self.backbone = build_efficientnet_lite(backbone_name, 1000)
        # Swap the classification head for a 6-D regression head.
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, 6)
        if not (pretrained and backbone_file):
            return
        # NOTE(review): load_pretrain runs after the head swap, so fc is already
        # 6-wide here. Presumably backbone_file holds weights for this same
        # 6-output model; confirm how load_pretrain treats a shape mismatch.
        self.backbone.load_pretrain(backbone_file)
        self.backbone.eval()

    def forward(self, x):
        """Return (batch, 3, 3) rotation matrices for a batch of images."""
        sixd = self.backbone(x)
        return compute_rotation_matrix_from_ortho6d(sixd)

# Conversions

## Reparameterization
Fuse the multi-branch (training-time) RepVGG blocks into equivalent single-branch (deploy-time) blocks.

In [None]:
# Mount Google Drive only when running inside Google Colab.
# Outside Colab the google.colab package is absent and this cell does nothing,
# so the notebook still runs top to bottom on a local kernel.
try:
    from google.colab import drive
except ImportError:
    drive = None

if drive is not None:
    drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
def load_filtered_state_dict(model, snapshot):
    """Load into ``model`` only those ``snapshot`` entries whose keys the model has.

    - Keys absent from the model's state dict are skipped.
    - Model entries missing from ``snapshot`` keep their current values.
    - Shapes are not checked: a matching key with a different shape still makes
      load_state_dict raise.

    Args:
        model: nn.Module updated in place.
        snapshot: mapping of state-dict key to tensor.

    Returns:
        Sorted list of ``snapshot`` keys that were ignored.

    Warns:
        UserWarning if ``snapshot`` is non-empty but none of its keys match
        (e.g. every key carries an extra 'module.' prefix). Otherwise the model
        would silently keep its initial weights.
    """
    # Adapted from a snippet by user apaszke on discuss.pytorch.org.
    import warnings

    model_dict = model.state_dict()
    matched = {k: v for k, v in snapshot.items() if k in model_dict}
    skipped = sorted(k for k in snapshot if k not in model_dict)
    if snapshot and not matched:
        warnings.warn(
            'load_filtered_state_dict: none of the %d snapshot keys match the '
            'model; weights were left unchanged.' % len(snapshot))
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    return skipped
def convert(backbone='RepVGG-AZ', checkpoint_path=None, save_path=None):
    """Reparameterize a trained SixDRepNet checkpoint into its deploy form.

    Builds the training-time model (deploy=False) and fills it from
    ``checkpoint['model_state_dict']`` via load_filtered_state_dict. It then
    calls repvgg_model_convert, which writes the converted weights to
    ``save_path``.

    Args:
        backbone: RepVGG variant name; defaults to 'RepVGG-AZ'.
        checkpoint_path: training checkpoint to read. Defaults to the global
            ``modelloc``, looked up at call time.
        save_path: output path for the converted weights. Defaults to the global
            ``pthlloc``, looked up at call time.
    """
    if checkpoint_path is None:
        checkpoint_path = modelloc
    if save_path is None:
        save_path = pthlloc

    print('Loading model.')
    model = SixDRepNet(backbone_name=backbone,
                       backbone_file='',
                       deploy=False,
                       pretrained=False)

    # The model is built on the CPU. map_location='cpu' lets GPU-saved
    # checkpoints load without CUDA being available.
    saved_state_dict = torch.load(checkpoint_path, map_location='cpu')
    load_filtered_state_dict(model, saved_state_dict['model_state_dict'])

    print('Converting model.')
    repvgg_model_convert(model, save_path=save_path)
    print('Done.')
# Runs the conversion on every execution of this cell: reads modelloc and (over)writes pthlloc.
convert()




Loading model.
RepVGG Block, identity =  None
RepVGG Block, identity =  None
RepVGG Block, identity =  None
RepVGG Block, identity =  BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
RepVGG Block, identity =  None
RepVGG Block, identity =  BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
RepVGG Block, identity =  BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
RepVGG Block, identity =  BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
RepVGG Block, identity =  BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
RepVGG Block, identity =  BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
RepVGG Block, identity =  BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
RepVGG Block, identity =  BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
RepVGG

## ONNX Conversion
Export the PyTorch models to ONNX format.

In [None]:
# Rebuild the deploy-form (deploy=True) SixDRepNet, load the weights that
# convert() wrote to pthlloc, and trace it to ONNX at onnxlloc.
# The model and dummy_input both live on the CPU, so load with
# map_location='cpu'; mapping to 'cuda' only made this cell fail on CPU-only hosts.
deploy_state_dict = torch.load(pthlloc, map_location='cpu')
modella = SixDRepNet(backbone_name='RepVGG-AZ',
                     backbone_file='',
                     deploy=True,
                     pretrained=False)
modella.load_state_dict(deploy_state_dict)
modella.eval()

# Tracing input: one 3-channel 224x224 image. Adjust if the model expects another size.
dummy_input = torch.randn(1, 3, 224, 224)

# Export the model to ONNX.
onnx_path = onnxlloc
torch.onnx.export(modella, dummy_input, onnx_path, verbose=True)

In [16]:
# Build the EfficientNet-Lite variant of the model and export it to ONNX.
# Variant table (presumably [width_mult, depth_mult, input_resolution, dropout] — TODO confirm):
#'efficientnet_lite0': [1.0, 1.0, 224, 0.2],
#'efficientnet_lite1': [1.0, 1.1, 240, 0.2],
#'efficientnet_lite2': [1.1, 1.2, 260, 0.3],
#'efficientnet_lite3': [1.2, 1.4, 280, 0.3],
#'efficientnet_lite4'
# NOTE(review): pthlloc and onnxlloc are the same paths the RepVGG cells use.
# convert() writes RepVGG weights to pthlloc, and the RepVGG export also writes
# onnxlloc. On Restart & Run All, this cell therefore loads those RepVGG weights
# and overwrites the RepVGG ONNX file. Confirm which backbone is intended.
# NOTE(review): efficientnet_lite3 is listed at 280 px above, but the trace below
# uses 224 px. Confirm the resolution the model was trained at.
modella = SixDENet(
    backbone_name='efficientnet_lite3',
    backbone_file=pthlloc,
    deploy=True,
    pretrained=True,
)
modella.eval()

onnx_path = onnxlloc
# Tracing input: one 3-channel 224x224 image.
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(modella, dummy_input, onnx_path, verbose=True, export_params=True)