In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

In [2]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # defining vgg block here
        # block 1
        self.conv64_vgg_in = nn.Conv2d(3, 64, 3, 1, 1)
        self.conv64_vgg = nn.Conv2d(64, 64, 3, 1, 1)
        
        # block 2
        self.conv128_vgg_in = nn.Conv2d(64, 128, 3, 1, 1)
        self.conv128_vgg = nn.Conv2d(128, 128, 3, 1, 1)
        
        # block 3
        self.conv256_vgg_in = nn.Conv2d(128, 256, 3, 1, 1)
        self.conv256_vgg = nn.Conv2d(256, 256, 3, 1, 1)
        
        # block 4
        self.conv512_vgg_in = nn.Conv2d(256, 512, 3, 1, 1)
        self.conv512_vgg = nn.Conv2d(512, 512, 3, 1, 1)
        
        # additional non vgg layers
        self.conv_extra1 = nn.Conv2d(512, 256, 3, 1, 1)
        self.conv_extra2 = nn.Conv2d(256, 128, 3, 1, 1)
        
        
        # defining stage 1 layers here
        self.stage_1_block = nn.Conv2d(128, 128, 3, 1, 1)
        self.stage_1_block_2nd_last = nn.Conv2d(128, 512, 1, 1, 0)
        self.stage_1_block_paf = nn.Conv2d(512, 38, 1, 1, 0)
        self.stage_1_block_confidence_map = nn.Conv2d(512, 19, 1, 1, 0)
        
        # defining stage T layers here
        self.stage_T_block_in = nn.Conv2d(185, 128, 7, 1, 3) #the concatenated layer before has 19 + 38 + 128 = 185 channels
        self.stage_T_block = nn.Conv2d(128, 128, 7, 1, 3)
        self.stage_T_block_2nd_last = nn.Conv2d(128, 128, 1, 1, 0)
        self.stage_T_block_paf = nn.Conv2d(128, 38, 1, 1, 0)
        self.stage_T_block_confidence_map = nn.Conv2d(128, 19, 1, 1, 0)
        
    def stage_1(self, x, last_layer):
        x = self.stage_1_block(x)
        x = F.relu(x)
        x = self.stage_1_block(x)
        x = F.relu(x)
        x = self.stage_1_block(x)
        x = F.relu(x)
        x = self.stage_1_block_2nd_last(x)
        x = F.relu(x)
        if(last_layer == 38):
            x = self.stage_1_block_paf(x)
        else:
            x = self.stage_1_block_confidence_map(x)
        return x
        
    def stage_T(self, x, last_layer):
        x = self.stage_T_block_in(x)
        x = F.relu(x)
        x = self.stage_T_block(x)
        x = F.relu(x)
        x = self.stage_T_block(x)
        x = F.relu(x)
        x = self.stage_T_block(x)
        x = F.relu(x)
        x = self.stage_T_block(x)
        x = F.relu(x)
        x = self.stage_T_block_2nd_last(x)
        x = F.relu(x)
        if(last_layer == 38):
            x = self.stage_T_block_paf(x)
        else:
            x = self.stage_T_block_confidence_map(x)
        return x
        
    def forward(self, x):
        x = self.conv64_vgg_in(x)
        x = F.relu(x)
        x = self.conv64_vgg(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, stride = 2)
        
        x = self.conv128_vgg_in(x)
        x = F.relu(x)
        x = self.conv128_vgg(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, stride = 2)
        
        x = self.conv256_vgg_in(x)
        x = F.relu(x)
        x = self.conv256_vgg(x)
        x = F.relu(x)
        x = self.conv256_vgg(x)
        x = F.relu(x)
        x = self.conv256_vgg(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, stride = 2)
        
        x = self.conv512_vgg_in(x)
        x = F.relu(x)
        x = self.conv512_vgg(x)
        x = F.relu(x)
        
        x = self.conv_extra1(x)
        x = F.relu(x)
        x = self.conv_extra2(x)
        x_vgg = F.relu(x)
        
        # defining stage 1 blocks
        x_1_Paf = self.stage_1(x_vgg, 38)
        x_1_confidence_map = self.stage_1(x_vgg, 19)
        
        # concatenating x_vgg, x_1_Paf, x_1_confidence_map
        concat = torch.cat([x_1_Paf, x_1_confidence_map, x_vgg], 1)
        
        # defining stage T blocks
        for i in range(2, 7):
            x_paf = self.stage_T(concat, 38)
            x_confidence_map = self.stage_T(concat, 19)
            if(i != 6):
                concat = torch.cat([x_paf, x_confidence_map, x_vgg], 1)
        
        return x_paf, x_confidence_map

In [5]:
model = Net()
model.cuda()
# summary(model , (3, 224, 224))
inpt = torch.rand((1, 3, 224, 224))
inpt = inpt.cuda()
print(inpt.shape)
output_paf , output_confidence_map = model(inpt)
print(output_paf.shape, output_confidence_map.shape)

torch.Size([1, 3, 224, 224])
torch.Size([1, 38, 28, 28]) torch.Size([1, 19, 28, 28])


In [8]:
print("Printing models State_dict : ")
for param_tensor in model.state_dict():
    print(param_tensor, '\t', model.state_dict()[param_tensor].shape)

Printing models State_dict : 
conv64_vgg_in.weight 	 torch.Size([64, 3, 3, 3])
conv64_vgg_in.bias 	 torch.Size([64])
conv64_vgg.weight 	 torch.Size([64, 64, 3, 3])
conv64_vgg.bias 	 torch.Size([64])
conv128_vgg_in.weight 	 torch.Size([128, 64, 3, 3])
conv128_vgg_in.bias 	 torch.Size([128])
conv128_vgg.weight 	 torch.Size([128, 128, 3, 3])
conv128_vgg.bias 	 torch.Size([128])
conv256_vgg_in.weight 	 torch.Size([256, 128, 3, 3])
conv256_vgg_in.bias 	 torch.Size([256])
conv256_vgg.weight 	 torch.Size([256, 256, 3, 3])
conv256_vgg.bias 	 torch.Size([256])
conv512_vgg_in.weight 	 torch.Size([512, 256, 3, 3])
conv512_vgg_in.bias 	 torch.Size([512])
conv512_vgg.weight 	 torch.Size([512, 512, 3, 3])
conv512_vgg.bias 	 torch.Size([512])
conv_extra1.weight 	 torch.Size([256, 512, 3, 3])
conv_extra1.bias 	 torch.Size([256])
conv_extra2.weight 	 torch.Size([128, 256, 3, 3])
conv_extra2.bias 	 torch.Size([128])
stage_1_block.weight 	 torch.Size([128, 128, 3, 3])
stage_1_block.bias 	 torch.Size([128

In [10]:
model.load_state_dict(torch.load('pose_model.pth'))

RuntimeError: Error(s) in loading state_dict for Net:
	Missing key(s) in state_dict: "conv64_vgg_in.weight", "conv64_vgg_in.bias", "conv64_vgg.weight", "conv64_vgg.bias", "conv128_vgg_in.weight", "conv128_vgg_in.bias", "conv128_vgg.weight", "conv128_vgg.bias", "conv256_vgg_in.weight", "conv256_vgg_in.bias", "conv256_vgg.weight", "conv256_vgg.bias", "conv512_vgg_in.weight", "conv512_vgg_in.bias", "conv512_vgg.weight", "conv512_vgg.bias", "conv_extra1.weight", "conv_extra1.bias", "conv_extra2.weight", "conv_extra2.bias", "stage_1_block.weight", "stage_1_block.bias", "stage_1_block_2nd_last.weight", "stage_1_block_2nd_last.bias", "stage_1_block_paf.weight", "stage_1_block_paf.bias", "stage_1_block_confidence_map.weight", "stage_1_block_confidence_map.bias", "stage_T_block_in.weight", "stage_T_block_in.bias", "stage_T_block.weight", "stage_T_block.bias", "stage_T_block_2nd_last.weight", "stage_T_block_2nd_last.bias", "stage_T_block_paf.weight", "stage_T_block_paf.bias", "stage_T_block_confidence_map.weight", "stage_T_block_confidence_map.bias". 
	Unexpected key(s) in state_dict: "model0.0.weight", "model0.0.bias", "model0.2.weight", "model0.2.bias", "model0.5.weight", "model0.5.bias", "model0.7.weight", "model0.7.bias", "model0.10.weight", "model0.10.bias", "model0.12.weight", "model0.12.bias", "model0.14.weight", "model0.14.bias", "model0.16.weight", "model0.16.bias", "model0.19.weight", "model0.19.bias", "model0.21.weight", "model0.21.bias", "model0.23.weight", "model0.23.bias", "model0.25.weight", "model0.25.bias", "model1_1.0.weight", "model1_1.0.bias", "model1_1.2.weight", "model1_1.2.bias", "model1_1.4.weight", "model1_1.4.bias", "model1_1.6.weight", "model1_1.6.bias", "model1_1.8.weight", "model1_1.8.bias", "model2_1.0.weight", "model2_1.0.bias", "model2_1.2.weight", "model2_1.2.bias", "model2_1.4.weight", "model2_1.4.bias", "model2_1.6.weight", "model2_1.6.bias", "model2_1.8.weight", "model2_1.8.bias", "model2_1.10.weight", "model2_1.10.bias", "model2_1.12.weight", "model2_1.12.bias", "model3_1.0.weight", "model3_1.0.bias", "model3_1.2.weight", "model3_1.2.bias", "model3_1.4.weight", "model3_1.4.bias", "model3_1.6.weight", "model3_1.6.bias", "model3_1.8.weight", "model3_1.8.bias", "model3_1.10.weight", "model3_1.10.bias", "model3_1.12.weight", "model3_1.12.bias", "model4_1.0.weight", "model4_1.0.bias", "model4_1.2.weight", "model4_1.2.bias", "model4_1.4.weight", "model4_1.4.bias", "model4_1.6.weight", "model4_1.6.bias", "model4_1.8.weight", "model4_1.8.bias", "model4_1.10.weight", "model4_1.10.bias", "model4_1.12.weight", "model4_1.12.bias", "model5_1.0.weight", "model5_1.0.bias", "model5_1.2.weight", "model5_1.2.bias", "model5_1.4.weight", "model5_1.4.bias", "model5_1.6.weight", "model5_1.6.bias", "model5_1.8.weight", "model5_1.8.bias", "model5_1.10.weight", "model5_1.10.bias", "model5_1.12.weight", "model5_1.12.bias", "model6_1.0.weight", "model6_1.0.bias", "model6_1.2.weight", "model6_1.2.bias", "model6_1.4.weight", "model6_1.4.bias", "model6_1.6.weight", "model6_1.6.bias", "model6_1.8.weight", "model6_1.8.bias", "model6_1.10.weight", "model6_1.10.bias", "model6_1.12.weight", "model6_1.12.bias", "model1_2.0.weight", "model1_2.0.bias", "model1_2.2.weight", "model1_2.2.bias", "model1_2.4.weight", "model1_2.4.bias", "model1_2.6.weight", "model1_2.6.bias", "model1_2.8.weight", "model1_2.8.bias", "model2_2.0.weight", "model2_2.0.bias", "model2_2.2.weight", "model2_2.2.bias", "model2_2.4.weight", "model2_2.4.bias", "model2_2.6.weight", "model2_2.6.bias", "model2_2.8.weight", "model2_2.8.bias", "model2_2.10.weight", "model2_2.10.bias", "model2_2.12.weight", "model2_2.12.bias", "model3_2.0.weight", "model3_2.0.bias", "model3_2.2.weight", "model3_2.2.bias", "model3_2.4.weight", "model3_2.4.bias", "model3_2.6.weight", "model3_2.6.bias", "model3_2.8.weight", "model3_2.8.bias", "model3_2.10.weight", "model3_2.10.bias", "model3_2.12.weight", "model3_2.12.bias", "model4_2.0.weight", "model4_2.0.bias", "model4_2.2.weight", "model4_2.2.bias", "model4_2.4.weight", "model4_2.4.bias", "model4_2.6.weight", "model4_2.6.bias", "model4_2.8.weight", "model4_2.8.bias", "model4_2.10.weight", "model4_2.10.bias", "model4_2.12.weight", "model4_2.12.bias", "model5_2.0.weight", "model5_2.0.bias", "model5_2.2.weight", "model5_2.2.bias", "model5_2.4.weight", "model5_2.4.bias", "model5_2.6.weight", "model5_2.6.bias", "model5_2.8.weight", "model5_2.8.bias", "model5_2.10.weight", "model5_2.10.bias", "model5_2.12.weight", "model5_2.12.bias", "model6_2.0.weight", "model6_2.0.bias", "model6_2.2.weight", "model6_2.2.bias", "model6_2.4.weight", "model6_2.4.bias", "model6_2.6.weight", "model6_2.6.bias", "model6_2.8.weight", "model6_2.8.bias", "model6_2.10.weight", "model6_2.10.bias", "model6_2.12.weight", "model6_2.12.bias". 