In [1]:
import torch
import math
from torch import nn, autograd, stack, cat
from torch.autograd import Variable
import torch.nn.functional as F

In [7]:
class ColorNet(nn.Module):
    """Two-layer 1x1-convolution network mapping 1 input channel to 3 channels.

    The forward pass computes ``conv1 -> bn1 -> ReLU -> conv2 -> bn2``, adds the
    input replicated three times along the channel axis as a residual, and
    applies a sigmoid to the sum.

    Args:
        hooks: when True, register a backward hook on both convolutions that
            prints gradient sizes and norms (debugging aid).
        pretrained: optional path to a checkpoint loadable with ``torch.load``
            whose ``'state_dict'`` entry holds the weights. When None, the
            convolution parameters are set by ``init_weights``.
    """

    def __init__(self, hooks=False, pretrained=None):
        super(ColorNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, 1)
        self.conv2 = nn.Conv2d(10, 3, 1)
        self.bn2 = nn.BatchNorm2d(3, affine=True)
        self.bn1 = nn.BatchNorm2d(10, affine=True)

        if pretrained is None:
            self.init_weights()
        else:
            # NOTE(security): torch.load unpickles the file and can execute
            # arbitrary code; only pass checkpoints from a trusted source.
            # The map_location callable keeps every storage on the CPU.
            state = torch.load(pretrained, map_location=lambda storage, loc: storage)
            self.load_state_dict(state['state_dict'])

        if hooks:
            for conv in (self.conv1, self.conv2):
                # Prefer the non-deprecated full backward hook; fall back to
                # the legacy one on torch versions that do not provide it.
                register = getattr(conv, 'register_full_backward_hook',
                                   conv.register_backward_hook)
                register(self._print_grad_norm)

    @staticmethod
    def _print_grad_norm(module, grad_input, grad_output):
        """Backward hook: print the size and norm of the first in/out gradient.

        An entry of ``grad_input`` / ``grad_output`` can be None (e.g. when the
        module input does not require grad); in that case ``None`` is printed
        instead of raising.
        """
        g_in = grad_input[0] if grad_input else None
        g_out = grad_output[0] if grad_output else None
        print('Inside ' + module.__class__.__name__ + ' backward')
        print('{} -> {}'.format(None if g_in is None else g_in.size(),
                                None if g_out is None else g_out.size()))
        print('grad_in norm: {}'.format(None if g_in is None else g_in.data.norm()))
        print('grad_out norm: {}'.format(None if g_out is None else g_out.data.norm()))

    def forward(self, x):
        """Return ``sigmoid(bn2(conv2(relu(bn1(conv1(x))))) + cat([x, x, x]))``.

        ``x`` must have exactly one channel on dim 1: ``conv1`` takes a single
        input channel, and the tripled input must match the 3 output channels.
        """
        residual = torch.cat([x, x, x], dim=1)

        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Out-of-place add: avoids mutating a tensor autograd may still need.
        out = out + residual
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(out)

    def init_weights(self):
        """Initialize conv weights from N(0, 0.02) and conv biases to zero."""
        for conv in (self.conv1, self.conv2):
            nn.init.normal_(conv.weight, 0, 0.02)
            nn.init.constant_(conv.bias, 0)

In [8]:
#c = ColorNet()
#c = ColorNet(pretrained='/home/frati/Grasping/models/colorer_model/model_best_acc52.pth.tar')

In [2]:
# Stand-in batch of shape (4, 3, 224, 224) filled by torch.rand.
fake_batch = Variable(torch.rand((4, 3, 224, 224)))

In [None]:
# All-ones tensor of shape (2, 2, 3, 3), used below to check torch.cat.
a = torch.ones(2, 2, 3, 3)

In [None]:
# Tensor of shape (2, 2, 3, 3) with every element equal to 2.
b = 2 * torch.ones(2, 2, 3, 3)

In [11]:
# Concatenating along dim 1 stacks the channel axes of a and b.
torch.cat((a, b), dim=1).size()

torch.Size([2, 4, 3, 3])