In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
from torchvision.ops import roi_pool
from wsddn import *
from collections import OrderedDict 
from models import *

In [2]:
def init_parameters(module):
    """Re-initialise a single Conv2d or Linear layer in place.

    The weight is drawn from a normal distribution with mean 0 and standard
    deviation 1e-2; the bias, when the layer has one, is set to zero. Modules
    whose exact type is neither ``nn.Conv2d`` nor ``nn.Linear`` are left
    untouched.

    Args:
        module: the module to (possibly) re-initialise.

    Returns:
        None. The module's parameters are modified in place.
    """
    if type(module) not in (nn.Conv2d, nn.Linear):
        return
    torch.nn.init.normal_(module.weight, mean=0.0, std=1e-2)
    # Layers created with bias=False expose ``bias`` as None, and
    # zeros_(None) would raise, so only touch a bias that actually exists.
    if module.bias is not None:
        torch.nn.init.zeros_(module.bias)

In [3]:
def copy_parameters(src, target):
    """Re-bind ``src.weight`` and ``src.bias`` to the parameters of ``target``.

    Mind the direction: despite the argument names, the values flow from
    ``target`` into ``src``. The Parameter objects are assigned, not cloned,
    so after the call both modules refer to the very same tensors.

    Args:
        src: module whose ``weight`` and ``bias`` are replaced.
        target: module providing the ``weight`` and ``bias`` to use.

    Raises:
        AssertionError: if the weight shapes or the bias shapes differ.
    """
    shared_attrs = ("weight", "bias")
    # Check every shape before assigning anything, so a mismatch leaves
    # ``src`` unmodified.
    for attr in shared_attrs:
        assert getattr(src, attr).size() == getattr(target, attr).size()
    for attr in shared_attrs:
        setattr(src, attr, getattr(target, attr))

In [5]:
class Combined_Alexnet(nn.Module):
    """AlexNet-style network with a two-stream proposal head and K refinement heads.

    The backbone and the fc67 / fc8c / fc8d layers are taken from a
    ``WSDDN_Alexnet`` whose weights are loaded from ``pretrained_path``.
    The first five feature modules are reused as-is; conv3-conv5 are rebuilt
    with ``dilation=2`` and given the pretrained conv weights.

    NOTE(review): ``WSDDN_Alexnet`` is not defined in this notebook; presumably
    it comes from one of the star imports (``wsddn`` / ``models``) -- confirm.

    Args:
        K: number of refinement heads (``refine0`` ... ``refine{K-1}``).
        groups: stored on the instance; currently only referenced by the
            disabled GroupNorm layer in the refinement heads.
        pretrained_path: checkpoint holding the ``WSDDN_Alexnet`` state dict.
            Defaults to the path that used to be hard-coded.
    """

    def __init__(self, K=3, groups=4,
                 pretrained_path="../pretrained/eb_2007_wsddn_alexnet.pt"):
        super(Combined_Alexnet, self).__init__()
        self.K = K
        self.groups = groups

        wsddn_alexnet = WSDDN_Alexnet()
        # map_location="cpu" lets a checkpoint saved from a GPU be read on a
        # machine without CUDA; load_state_dict copies the values into the
        # freshly built module's own parameters either way.
        state_dict = torch.load(pretrained_path, map_location="cpu")
        wsddn_alexnet.load_state_dict(state_dict)

        # First five feature modules of the pretrained network, re-wrapped in
        # a new Sequential (keys become "0".."4").
        self.pretrained_features = nn.Sequential(
            *list(wsddn_alexnet.features[:5]._modules.values()))

        # padding=2 with dilation=2 keeps the spatial size of a 3x3 conv.
        self.new_features = nn.Sequential(OrderedDict([
            ('conv3', nn.Conv2d(192, 384, kernel_size=3, stride=1, padding=2, dilation=2)),
            ('relu3', nn.ReLU(inplace=True)),
            ('conv4', nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=2, dilation=2)),
            ('relu4', nn.ReLU(inplace=True)),
            ('conv5', nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=2, dilation=2)),
            ('relu5', nn.ReLU(inplace=True)),
        ]))

        # Give the dilated convs the pretrained conv parameters
        # (copy_parameters assigns its second argument's weight/bias to the first).
        copy_parameters(self.new_features.conv3, wsddn_alexnet.features[6])
        copy_parameters(self.new_features.conv4, wsddn_alexnet.features[8])
        copy_parameters(self.new_features.conv5, wsddn_alexnet.features[10])

        self.roi_size = (6, 6)
        # NOTE(review): presumably 1/8 matches the stride of the truncated
        # backbone -- confirm against WSDDN_Alexnet.features.
        self.roi_spatial_scale = 0.125

        # Fully connected layers are reused directly from the pretrained net.
        self.fc67 = nn.Sequential(*list(wsddn_alexnet.fc67._modules.values()))
        self.fc8c = wsddn_alexnet.fc8c
        self.fc8d = wsddn_alexnet.fc8d
        self.c_softmax = nn.Softmax(dim=1)  # normalises across columns (dim 1)
        self.d_softmax = nn.Softmax(dim=0)  # normalises across rows (dim 0)

        # One refinement head per k, registered as refine{k} so the
        # state-dict keys stay refine{k}.ic_score{k}.*
        for i in range(self.K):
            self.add_module(
                f'refine{i}',
                nn.Sequential(OrderedDict([
                    # (f'groupNorm', nn.GroupNorm(self.groups, 4096)),  # disabled
                    (f'ic_score{i}', nn.Linear(4096, 21)),
                    (f'ic_probs{i}', nn.Softmax(dim=1))
                ])))

    def forward(self, x, regions):
        """Score the region proposals of one image.

        Args:
            x: input passed to the convolutional backbone.
            regions: indexable collection; only ``regions[0]`` is used, so a
                single image per call is effectively assumed. ``regions[0]``
                is expected to be a ``Tensor(R, 4)`` of boxes -- TODO confirm
                the coordinate convention against the data loader.

        Returns:
            A tuple ``(refine_scores, proposal_scores)``:
            ``refine_scores`` is a list with one output per refinement head
            (``K`` entries), and ``proposal_scores`` is the element-wise
            product of the dim-1 softmax of ``fc8c`` and the dim-0 softmax of
            ``fc8d``.
        """
        regions = [regions[0]]  # roi_pool requires [Tensor(K, 4)]
        R = len(regions[0])
        features = self.new_features(self.pretrained_features(x))
        pooled = roi_pool(features, regions, self.roi_size, self.roi_spatial_scale)
        pool_features = pooled.view(R, -1)  # flatten each pooled region to a vector
        fc7 = self.fc67(pool_features)

        c_score = self.c_softmax(self.fc8c(fc7))
        d_score = self.d_softmax(self.fc8d(fc7))
        proposal_scores = c_score * d_score

        # Look the heads up through the public attribute API rather than the
        # private _modules dict.
        refine_scores = [getattr(self, f'refine{i}')(fc7) for i in range(self.K)]
        return refine_scores, proposal_scores

In [4]:
# Build the model with its default constructor arguments.
# NOTE(review): Combined_VGG16 is not defined in any cell shown here; presumably it
# comes from one of the star imports (`wsddn` / `models`) -- confirm.
model = Combined_VGG16()

In [5]:
# Display the module's repr to inspect its layer structure.
model

Combined_VGG16(
  (pretrained_features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2,

In [14]:
# Inspect the conv5_1 weight of the model's new_features block.
# NOTE(review): this displays the whole parameter tensor; `.shape` or a small slice
# would keep the saved output short.
model.new_features.conv5_1.weight

Parameter containing:
tensor([[[[-2.9665e-02, -1.8004e-02, -2.0759e-02],
          [-2.3392e-02, -8.0310e-03, -1.2426e-02],
          [-1.7737e-02, -1.7743e-02, -2.7021e-02]],

         [[-2.1687e-02, -2.5525e-02, -6.7156e-03],
          [-1.3061e-02, -1.7542e-02, -1.5949e-02],
          [-2.4782e-02, -1.0104e-02, -4.9635e-03]],

         [[-4.2964e-03, -3.6401e-03, -6.3512e-03],
          [ 1.0383e-02,  2.3713e-02,  1.4648e-02],
          [-9.3529e-03, -1.2316e-03, -8.4364e-03]],

         ...,

         [[ 2.5568e-02,  3.8466e-02,  2.1158e-02],
          [ 4.5726e-02,  5.0937e-02,  3.1546e-02],
          [ 3.8903e-02,  4.3654e-02,  3.7143e-02]],

         [[ 2.6901e-02,  1.0798e-02,  1.9648e-02],
          [ 2.8889e-03, -8.7730e-03,  1.4934e-02],
          [ 4.2627e-03, -8.3357e-03,  2.3783e-02]],

         [[ 1.5674e-02,  1.1260e-02,  1.9177e-02],
          [ 2.7215e-02,  9.9037e-03,  2.5009e-02],
          [ 9.9336e-03,  7.7196e-03, -1.3929e-03]]],


        [[[-2.3820e-02, -2.3844

In [8]:
# Reference torchvision VGG16, requested with pretrained weights, to compare against `model`.
# NOTE(review): newer torchvision releases prefer the `weights=` argument over
# `pretrained=` -- confirm against the installed version.
vgg = torchvision.models.vgg16(pretrained=True)

In [16]:
# Element-wise check that the model's conv5_1 bias equals the bias of VGG16 features[24];
# the result is a boolean tensor with one entry per bias element.
vgg.features[24].bias == model.new_features.conv5_1.bias

tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, Tr