In [45]:
# %load resnet18.py
from torch import nn
import torch
from torch.nn import functional as F
import numpy as np
from resnet import resnet18
from singlestage import SimplePnPNet
from utils import quaternion2rotation


class Resnet18(nn.Module):
    def __init__(self, ver_dim, seg_dim, fcdim=256, s8dim=128, s4dim=64, s2dim=32, raw_dim=32):
        super(Resnet18, self).__init__()

        # Load the pretrained weights, remove avg pool
        # layer and get the output stride of 8
        resnet18_8s = resnet18(fully_conv=True,
                               pretrained=True,
                               output_stride=8,
                               remove_avg_pool_layer=True)

        self.ver_dim=ver_dim
        self.seg_dim=seg_dim

        # Randomly initialize the 1x1 Conv scoring layer
        resnet18_8s.fc = nn.Sequential(
            nn.Conv2d(resnet18_8s.inplanes, fcdim, 3, 1, 1, bias=False),
            nn.BatchNorm2d(fcdim),
            nn.ReLU(True)
        )
        self.resnet18_8s = resnet18_8s

        # x8s->128
        self.conv8s=nn.Sequential(
            nn.Conv2d(128+fcdim, s8dim, 3, 1, 1, bias=False),
            nn.BatchNorm2d(s8dim),
            nn.LeakyReLU(0.1,True)
        )
        self.up8sto4s=nn.UpsamplingBilinear2d(scale_factor=2)
        # x4s->64
        self.conv4s=nn.Sequential(
            nn.Conv2d(64+s8dim, s4dim, 3, 1, 1, bias=False),
            nn.BatchNorm2d(s4dim),
            nn.LeakyReLU(0.1,True)
        )

        # x2s->64
        self.conv2s=nn.Sequential(
            nn.Conv2d(64+s4dim, s2dim, 3, 1, 1, bias=False),
            nn.BatchNorm2d(s2dim),
            nn.LeakyReLU(0.1,True)
        )
        self.up4sto2s=nn.UpsamplingBilinear2d(scale_factor=2)

        self.convraw = nn.Sequential(
            nn.Conv2d(3+s2dim, raw_dim, 3, 1, 1, bias=False),
            nn.BatchNorm2d(raw_dim),
            nn.LeakyReLU(0.1,True),
            nn.Conv2d(raw_dim, seg_dim+ver_dim, 1, 1)
        )
        self.up2storaw = nn.UpsamplingBilinear2d(scale_factor=2)
        
        self.single_stage = SimplePnPNet(4).cuda()

    def _normal_initialization(self, layer):
        layer.weight.data.normal_(0, 0.01)
        layer.bias.data.zero_()


    def forward(self, x, feature_alignment=False):
        x2s, x4s, x8s, x16s, x32s, xfc = self.resnet18_8s(x)

        fm=self.conv8s(torch.cat([xfc,x8s],1))
        fm=self.up8sto4s(fm)
        if fm.shape[2]==136:
            fm = nn.functional.interpolate(fm, (135,180), mode='bilinear', align_corners=False)

        fm=self.conv4s(torch.cat([fm,x4s],1))
        fm=self.up4sto2s(fm)

        fm=self.conv2s(torch.cat([fm,x2s],1))
        fm=self.up2storaw(fm)

        x=self.convraw(torch.cat([fm,x],1))
        seg_pred=x[:,:self.seg_dim,:,:]
        ver_pred=x[:,self.seg_dim:,:,:]
        
        
        #####################################################################
        # single stage model
        #####################################################################
        
        # x refers to horizontal coord and y refers to vertical coord
        # refer to /lib/utils/pvnet/pvnet_data_utils.py 'compute_vertex' function
        pred_dx = []
        pred_dy = []
        pred_x = []
        pred_y = []
        
        ver_pred_reshape = ver_pred.permute(0, 2, 3, 1)
        batch_size, h, w, vn_2 = ver_pred_reshape.shape
        ver_pred_reshape = ver_pred_reshape.reshape(batch_size, h, w, vn_2//2, 2)
        
        batch_seg_mask = torch.argmax(seg_pred, dim=1)
        
        batch_idx_used = []
        for b in range(batch_size):
            seg_mask = batch_seg_mask[b]
            seg_indices = seg_mask.nonzero().cuda()
            # randomly sample 200 points
            if seg_indices.shape[0] >= 200:
                batch_idx_used.append(b)
                
#                 sampled_indices = torch.from_numpy(np.random.choice(range(seg_indices.shape[0]), 200, replace=False)).cuda()
                sampled_indices = torch.arange(200).cuda()
                
                dx = ver_pred_reshape[b, seg_indices[sampled_indices,0], seg_indices[sampled_indices,1], :, 0]
                dx = dx.view(-1,1).squeeze()
                dy = ver_pred_reshape[b, seg_indices[sampled_indices,0], seg_indices[sampled_indices,1], :, 1]
                dy = dy.view(-1,1).squeeze()

                pred_dx.append(dx)
                pred_dy.append(dy)
                
                px = seg_indices[sampled_indices,1].float() / w
                px = px.repeat_interleave(vn_2//2)
                py = seg_indices[sampled_indices,0].float() / h
                py = py.repeat_interleave(vn_2//2)
                pred_x.append(px)
                pred_y.append(py)
        
        if len(batch_idx_used) != 0:
            pred_dx = torch.stack(pred_dx, 0)
            pred_dy = torch.stack(pred_dy, 0)
            pred_x = torch.stack(pred_x, 0) - 0.5
            pred_y = torch.stack(pred_y, 0) - 0.5

            pred_xydxdy = torch.stack([pred_x, pred_y, pred_dx, pred_dy], 2)
            pred_xydxdy = pred_xydxdy.permute(0,2,1)

            pred_pose = self.single_stage(pred_xydxdy)
        
        batch_pred_pose = torch.zeros(batch_size, 3, 4).cuda()
        for idx,b in enumerate(batch_idx_used):
            pp = pred_pose[idx]
            pq = pp[:4]
            pt = pp[4:]
            pr = quaternion2rotation(pq)
            prr = torch.cat([pr, pt.unsqueeze(1)],1)
            batch_pred_pose[b] = prr
                

        ret = {'seg': seg_pred, 'vertex': ver_pred, 'pred_pose':batch_pred_pose}

        return [seg_pred, ver_pred, batch_pred_pose]


def get_res_pvnet(ver_dim, seg_dim):

    model = Resnet18(ver_dim, seg_dim)
    return model



In [46]:
model = Resnet18(18,2).cuda()

In [3]:
import hiddenlayer as hl
hl.build_graph(model, torch.ones(1,3,256,256))

AttributeError: module 'torch.jit' has no attribute '_get_trace_graph'

In [47]:
x = torch.ones(1,3,256,256).cuda()
y = model(x)

In [48]:
from torchviz import make_dot_from_trace
with torch.onnx.set_training(model, False):
    trace, _ = torch.jit.get_trace_graph(model, args=(x,))
make_dot_from_trace(trace)



RuntimeError: Failed to export an ONNX attribute, since it's not constant, please try to make things (e.g., kernel size) static if possible