In [3]:
import torch
import torch.nn as nn
import torchvision.models as models

class HourglassPose(nn.Module):
    """Encoder-decoder ("hourglass") pose regressor on a ResNet34 backbone.

    The encoder is torchvision's ResNet34 with its final average-pool and
    fully-connected layers removed. Four stride-2 transposed convolutions
    upsample the bottleneck (512 -> 256 -> 128 -> 64 -> 32 channels). The
    bottleneck and the first three decoder stages are each projected to 32
    channels by a 1x1 "skip" convolution, bilinearly resized to the final
    decoder resolution and added to the decoder output. The fused map is
    refined by ``conv_final``, pooled to a fixed 56x56 grid and flattened
    for three linear heads.

    NOTE(review): the previous forward pass concatenated the 32-channel skip
    tensors onto the decoder features before layers that expect the
    un-concatenated channel count, and mixed tensors of different spatial
    sizes, so it raised at run time. Additive fusion was chosen because it
    keeps every layer's name and parameter shape unchanged; confirm it
    matches the intended architecture.

    Args:
        pretrained: If True (default, the original behaviour) the backbone is
            initialised with ImageNet weights, which may trigger a download.
            Pass False for a randomly initialised backbone.

    Inputs (forward):
        x: image batch; assumed to be ``(batch, 3, H, W)`` as required by the
            ResNet stem -- TODO confirm the expected resolution/normalisation.

    Returns (forward):
        Tuple ``(loc, ori, trans)`` of shapes ``(batch, 3)``, ``(batch, 4)``
        and ``(batch, 3)``. The 4-vector is presumably a quaternion -- verify
        against the training code.
    """

    # Spatial grid the fused feature map is pooled to before the linear heads.
    # The heads are declared with 32 * 56 * 56 input features, so this must
    # stay in sync with them.
    REGRESSOR_GRID = (56, 56)

    def __init__(self, pretrained=True):
        super(HourglassPose, self).__init__()

        # Load ResNet34 as encoder. torchvision >= 0.13 deprecates the
        # ``pretrained=`` flag in favour of ``weights=``; IMAGENET1K_V1 is the
        # checkpoint that ``pretrained=True`` used to select.
        if hasattr(models, "ResNet34_Weights"):
            weights = models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None
            resnet34 = models.resnet34(weights=weights)
        else:
            resnet34 = models.resnet34(pretrained=pretrained)
        # Drop the trailing avgpool and fc layers, keeping the conv trunk.
        self.encoder = nn.Sequential(*list(resnet34.children())[:-2])

        # Decoder: each transposed conv doubles the spatial size
        # (kernel 3, stride 2, padding 1, output_padding 1).
        self.deconv1 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv2 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv3 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv4 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv_final = nn.Conv2d(32, 32, kernel_size=3, padding=1)

        # Skip connections: 1x1 projections of each scale down to 32 channels
        # so they can be summed with the 32-channel decoder output.
        self.skip1 = nn.Conv2d(512, 32, kernel_size=1)
        self.skip2 = nn.Conv2d(256, 32, kernel_size=1)
        self.skip3 = nn.Conv2d(128, 32, kernel_size=1)
        self.skip4 = nn.Conv2d(64, 32, kernel_size=1)

        # Regressor heads operating on the flattened 32 x 56 x 56 feature map.
        self.fc_loc = nn.Linear(32 * 56 * 56, 3)
        self.fc_ori = nn.Linear(32 * 56 * 56, 4)
        self.fc_trans = nn.Linear(32 * 56 * 56, 3)

    def forward(self, x):
        # Encoder
        bottleneck = self.encoder(x)

        # Decoder: keep every intermediate stage so it can feed a skip branch.
        stage1 = self.deconv1(bottleneck)
        stage2 = self.deconv2(stage1)
        stage3 = self.deconv3(stage2)
        fused = self.deconv4(stage3)

        # Multi-scale fusion: project each scale to 32 channels, resize to the
        # final decoder resolution, and add (channel count stays at 32).
        target_size = fused.shape[-2:]
        skip_branches = (
            (self.skip1, bottleneck),
            (self.skip2, stage1),
            (self.skip3, stage2),
            (self.skip4, stage3),
        )
        for project, feature in skip_branches:
            fused = fused + nn.functional.interpolate(
                project(feature), size=target_size, mode="bilinear", align_corners=False
            )

        fused = self.conv_final(fused)

        # Pool to the fixed grid the linear heads were sized for, so the model
        # accepts any input resolution instead of failing on a size mismatch.
        fused = nn.functional.adaptive_avg_pool2d(fused, self.REGRESSOR_GRID)

        # Regressor
        flat = fused.flatten(1)
        loc = self.fc_loc(flat)
        ori = self.fc_ori(flat)
        trans = self.fc_trans(flat)

        return loc, ori, trans

In [1]:
import torch

# Installation smoke test: draw a 5x3 tensor of uniform samples from [0, 1)
# and print it. No seed is set, so the values differ on every run.
x = torch.rand((5, 3))
print(x)

tensor([[0.1716, 0.0628, 0.9276],
        [0.7938, 0.2581, 0.9955],
        [0.2131, 0.1500, 0.4318],
        [0.4466, 0.2244, 0.0838],
        [0.5996, 0.4756, 0.7389]])


In [2]:
import torch
# Ask PyTorch whether CUDA is usable in this environment. The bare expression is
# the cell's displayed result (the recorded output for this run is `True`).
torch.cuda.is_available()

True