In [2]:
import numpy as np

In [21]:
import math
import torch
import torch.nn.functional as F
from torch import nn
from torchvision import models
# from utils.helpers import initialize_weights, set_trainable
from itertools import chain
import sys
sys.path.append("/mnt/batch/tasks/shared/LS_root/mounts/clusters/pyfuse/code/pyramid-fuse")
from Utils import *


In [4]:
class _PSPModule(nn.Module):
    def __init__(self, in_channels, bin_sizes, norm_layer):
        super(_PSPModule, self).__init__()
        out_channels = in_channels // len(bin_sizes)
        self.stages = nn.ModuleList([self._make_stages(in_channels, out_channels, b_s, norm_layer) 
                                                        for b_s in bin_sizes])
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels+(out_channels * len(bin_sizes)), out_channels, 
                                    kernel_size=3, padding=1, bias=False),
            norm_layer(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1)
        )

    def _make_stages(self, in_channels, out_channels, bin_sz, norm_layer):
        prior = nn.AdaptiveAvgPool2d(output_size=bin_sz)
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        bn = norm_layer(out_channels)
        relu = nn.ReLU(inplace=True)
        return nn.Sequential(prior, conv, bn, relu)
    
    def forward(self, features):
        h, w = features.size()[2], features.size()[3]
        pyramids = [features]
        pyramids.extend([F.interpolate(stage(features), size=(h, w), mode='bilinear', 
                                        align_corners=True) for stage in self.stages])
        output = self.bottleneck(torch.cat(pyramids, dim=1))
        return output

In [6]:
class PSPNet(nn.Module):
    def __init__(self, num_classes=21, use_aux=True):
        super(PSPNet, self).__init__()
        # TODO: Use synch batchnorm
        norm_layer = nn.InstanceNorm2d
#         model = getattr(resnet, backbone)(pretrained, norm_layer=norm_layer, )
#         m_out_sz = model.fc.in_features
        self.use_aux = use_aux 

#         self.initial = nn.Sequential(*list(model.children())[:4])
#         if in_channels != 3:
#             self.initial[0] = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
#         self.initial = nn.Sequential(*self.initial)
        
#         self.layer1 = model.layer1
#         self.layer2 = model.layer2
#         self.layer3 = model.layer3
#         self.layer4 = model.layer4

        self.master_branch = nn.Sequential(
            _PSPModule(1024, bin_sizes=[1, 2, 3, 6], norm_layer=norm_layer),
            nn.Conv2d(1024//4, num_classes, kernel_size=1)
        )

        self.auxiliary_branch = nn.Sequential(
            nn.Conv2d(1024, 1024//2, kernel_size=3, padding=1, bias=False),
            norm_layer(1024//2),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(1024//2, num_classes, kernel_size=1)
        )

#         initialize_weights(self.master_branch, self.auxiliary_branch)
#         if freeze_bn: self.freeze_bn()
#         if freeze_backbone: 
#             set_trainable([self.initial, self.layer1, self.layer2, self.layer3, self.layer4], False)

    def forward(self, x):
        input_size = (512, 1024)
        if x.shape[0]%6==0:
            input_size = (32*x.size()[2], 32*x.size()[3])
#         x = self.initial(x)
#         x = self.layer1(x)
#         x = self.layer2(x)
#         x_aux = self.layer3(x)
#         x = self.layer4(x_aux)
        x_aux = x
        output = self.master_branch(x)
        output = F.interpolate(output, size=input_size, mode='bilinear')
        output = output[:, :, :input_size[0], :input_size[1]]

        if self.training and self.use_aux:
            aux = self.auxiliary_branch(x_aux)
            aux = F.interpolate(aux, size=input_size, mode='bilinear')
            aux = aux[:, :, :input_size[0], :input_size[1]]
            return output
        return output

In [7]:
psp = PSPNet(21)

In [8]:
feat_equi = torch.randn([1, 1024, 16, 32])
feat_cube = torch.randn([6, 1024, 8, 8])

In [9]:
feat_equi = psp(feat_equi)
feat_cube = psp(feat_cube)

  "See the documentation of nn.Upsample for details.".format(mode))


In [10]:
feat_equi.shape

torch.Size([1, 21, 512, 1024])

In [11]:
feat_cube.shape

torch.Size([6, 21, 256, 256])

In [24]:
class CETransform(nn.Module):
    def __init__(self):
        super(CETransform, self).__init__()
        equ_h = [512, 128, 64, 32, 16]
        cube_h = [256, 64, 32, 16, 8]

        self.c2e = dict()
        self.e2c = dict()

        for h in equ_h:
            a = Equirec2Cube(1, h, h*2, h//2, 90)
            self.e2c['(%d,%d)' % (h, h*2)] = a

        for h in cube_h:
            a = Cube2Equirec(1, h, h*2, h*4)
            self.c2e['(%d)' % (h)] = a

    def E2C(self, x):
        print(x.shape)
        [bs, c, h, w] = x.shape
        key = '(%d,%d)' % (h, w)
        print(key)
        assert key in self.e2c
        return self.e2c[key].ToCubeTensor(x)

    def C2E(self, x):
        print(x.shape)
        [bs, c, h, w] = x.shape
        key = '(%d)' % (h)
        assert key in self.c2e and h == w
        return self.c2e[key].ToEquirecTensor(x)

    def forward(self, equi, cube):
        return self.e2c(equi), self.c2e(cube)


In [25]:
ce = CETransform()

In [26]:
feat_cube = feat_cube.cuda()

In [27]:
feat_cube = ce.C2E(feat_cube)

torch.Size([6, 21, 256, 256])




In [28]:
feat_cube.shape

torch.Size([1, 21, 512, 1024])

In [176]:
feat_equi =  feat_equi.cuda() 

In [177]:
feat_equi.shape

torch.Size([1, 21, 512, 1024])

In [178]:
equi = torch.randn((1,3,512,1024))

In [179]:
feat_cat = torch.cat((feat_equi, feat_cube), dim = 1).cuda()
feat_cat.shape

torch.Size([1, 42, 512, 1024])

In [6]:
class Refine(nn.Module):
    def __init__(self):
        super(Refine, self).__init__()
        self.refine_1 = nn.Sequential(
                        nn.Conv2d(42, 32, kernel_size=3, stride=1, padding=1, bias=False),
                        nn.BatchNorm2d(32),
                        nn.ReLU(inplace=True),
                        nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias=False),
                        nn.BatchNorm2d(64),
                        nn.ReLU(inplace=True),
                        nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
                        nn.BatchNorm2d(64),
                        nn.ReLU(inplace=True)
                        )
        self.refine_2 = nn.Sequential(
                        nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
                        nn.BatchNorm2d(128),
                        nn.ReLU(inplace=True),
                        nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
                        nn.BatchNorm2d(128),
                        nn.ReLU(inplace=True),
                        )
        self.deconv_1 = nn.Sequential(
                        nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, output_padding=0, groups=1, bias=True, dilation=1),
                        nn.BatchNorm2d(64),
                        nn.LeakyReLU(inplace=True),
                        )
        self.deconv_2 = nn.Sequential(
                        nn.ConvTranspose2d(192, 32, kernel_size=4, stride=2, padding=1, output_padding=0, groups=1, bias=True, dilation=1),
                        nn.BatchNorm2d(32),
                        nn.LeakyReLU(inplace=True),
                        )
        self.refine_3 = nn.Sequential(
                        nn.Conv2d(96, 16, kernel_size=3, stride=1, padding=1, bias=False),
                        nn.BatchNorm2d(16),
                        nn.ReLU(inplace=True),
                        nn.Conv2d(16, 21, kernel_size=3, stride=1, padding=1, bias=False)
                        )
        self.bilinear_1 = nn.UpsamplingBilinear2d(size=(256,512))
        self.bilinear_2 = nn.UpsamplingBilinear2d(size=(512,1024))
    def forward(self, inputs):
        x = inputs
        out_1 = self.refine_1(x)
        out_2 = self.refine_2(out_1)
        deconv_out1 = self.deconv_1(out_2)
        up_1 = self.bilinear_1(out_2)
        deconv_out2 = self.deconv_2(torch.cat((deconv_out1, up_1), dim = 1))
        up_2 = self.bilinear_2(out_1)
        out_3 = self.refine_3(torch.cat((deconv_out2, up_2), dim = 1))

        return out_3                

In [7]:
refine_model = Refine().cuda()

In [182]:
refine = refine_model(feat_cat)

In [183]:
refine.shape

torch.Size([1, 21, 512, 1024])

In [184]:
refine

tensor([[[[ 0.4447,  0.7250,  0.4547,  ...,  0.4395,  0.1271, -0.0041],
          [ 0.2133,  0.3017,  0.4479,  ...,  0.5892,  0.2444, -0.1405],
          [ 0.2404,  0.5208,  0.4556,  ...,  0.2375,  0.2502, -0.1042],
          ...,
          [ 0.4903,  0.3270,  0.5936,  ...,  0.0253,  0.1814,  0.0502],
          [ 0.2983,  0.4196,  0.4389,  ...,  0.0290,  0.0433, -0.2088],
          [ 0.2918,  0.1080,  0.2433,  ..., -0.0273, -0.0520, -0.0960]],

         [[-0.0715,  0.0340,  0.1300,  ...,  0.2734, -0.0283,  0.0191],
          [-0.2055,  0.0406, -0.0135,  ..., -0.1115, -0.2554, -0.0389],
          [ 0.0036,  0.0234,  0.0876,  ..., -0.1132, -0.0589,  0.1346],
          ...,
          [ 0.4731,  0.4156,  0.3527,  ...,  0.4907,  0.2105,  0.1776],
          [ 0.5284,  0.1091,  0.3086,  ...,  0.1223,  0.0707,  0.2982],
          [ 0.1869,  0.1058,  0.1558,  ...,  0.0986,  0.1512,  0.0295]],

         [[-0.0485, -0.0502,  0.0602,  ..., -0.0523,  0.1897,  0.0685],
          [ 0.1236, -0.0082,  

In [23]:
params = torch.load("./BiFuse_Pretrained.pkl")

In [5]:
# params = torch.load(name)
model.load_state_dict(params, strict=False)

AttributeError: 'collections.OrderedDict' object has no attribute 'load_state_dict'

In [24]:
keys = list(params.keys())

In [50]:
keys

['equi_model.conv1.weight',
 'equi_model.bn1.weight',
 'equi_model.bn1.bias',
 'equi_model.bn1.running_mean',
 'equi_model.bn1.running_var',
 'equi_model.bn1.num_batches_tracked',
 'equi_model.layer1.0.conv1.weight',
 'equi_model.layer1.0.bn1.weight',
 'equi_model.layer1.0.bn1.bias',
 'equi_model.layer1.0.bn1.running_mean',
 'equi_model.layer1.0.bn1.running_var',
 'equi_model.layer1.0.bn1.num_batches_tracked',
 'equi_model.layer1.0.conv2.weight',
 'equi_model.layer1.0.bn2.weight',
 'equi_model.layer1.0.bn2.bias',
 'equi_model.layer1.0.bn2.running_mean',
 'equi_model.layer1.0.bn2.running_var',
 'equi_model.layer1.0.bn2.num_batches_tracked',
 'equi_model.layer1.0.conv3.weight',
 'equi_model.layer1.0.bn3.weight',
 'equi_model.layer1.0.bn3.bias',
 'equi_model.layer1.0.bn3.running_mean',
 'equi_model.layer1.0.bn3.running_var',
 'equi_model.layer1.0.bn3.num_batches_tracked',
 'equi_model.layer1.0.downsample.0.weight',
 'equi_model.layer1.0.downsample.1.weight',
 'equi_model.layer1.0.downsamp

In [47]:
sub = 'refine'

In [48]:
for i in keys:
    if sub in i:
        print (i)

refine_model.refine_1.0.weight
refine_model.refine_1.1.weight
refine_model.refine_1.1.bias
refine_model.refine_1.1.running_mean
refine_model.refine_1.1.running_var
refine_model.refine_1.1.num_batches_tracked
refine_model.refine_1.3.weight
refine_model.refine_1.4.weight
refine_model.refine_1.4.bias
refine_model.refine_1.4.running_mean
refine_model.refine_1.4.running_var
refine_model.refine_1.4.num_batches_tracked
refine_model.refine_1.6.weight
refine_model.refine_1.7.weight
refine_model.refine_1.7.bias
refine_model.refine_1.7.running_mean
refine_model.refine_1.7.running_var
refine_model.refine_1.7.num_batches_tracked
refine_model.refine_2.0.weight
refine_model.refine_2.1.weight
refine_model.refine_2.1.bias
refine_model.refine_2.1.running_mean
refine_model.refine_2.1.running_var
refine_model.refine_2.1.num_batches_tracked
refine_model.refine_2.3.weight
refine_model.refine_2.4.weight
refine_model.refine_2.4.bias
refine_model.refine_2.4.running_mean
refine_model.refine_2.4.running_var
refi

In [53]:
sub2 = 'dec'

for j in keys:
    if sub2 in j:
        print(j)

refine_model.deconv_1.0.weight
refine_model.deconv_1.0.bias
refine_model.deconv_1.1.weight
refine_model.deconv_1.1.bias
refine_model.deconv_1.1.running_mean
refine_model.deconv_1.1.running_var
refine_model.deconv_1.1.num_batches_tracked
refine_model.deconv_2.0.weight
refine_model.deconv_2.0.bias
refine_model.deconv_2.1.weight
refine_model.deconv_2.1.bias
refine_model.deconv_2.1.running_mean
refine_model.deconv_2.1.running_var
refine_model.deconv_2.1.num_batches_tracked
equi_decoder.layer1.upper_branch.conv1.weight
equi_decoder.layer1.upper_branch.batchnorm1.weight
equi_decoder.layer1.upper_branch.batchnorm1.bias
equi_decoder.layer1.upper_branch.batchnorm1.running_mean
equi_decoder.layer1.upper_branch.batchnorm1.running_var
equi_decoder.layer1.upper_branch.batchnorm1.num_batches_tracked
equi_decoder.layer1.upper_branch.conv2.weight
equi_decoder.layer1.upper_branch.batchnorm2.weight
equi_decoder.layer1.upper_branch.batchnorm2.bias
equi_decoder.layer1.upper_branch.batchnorm2.running_mean


In [54]:
for key in keys:
    if sub in key:
        keys.remove(key)

In [56]:
for key2 in keys:
    if sub2 in key2:
        keys.remove(key2)

In [57]:
len(keys)

853

In [65]:
len(params)

953

In [73]:
for key_ele in list(params.keys()):
    if sub in key_ele:
        del params[key_ele]

In [75]:
for key_ele2 in list(params.keys()):
    if sub2 in key_ele2:
        del params[key_ele2]

In [76]:
len(params)

760

In [81]:
new_keys = list(params.keys())

In [84]:
torch.save(params, 'params.pkl')