In [29]:
import random

import torch
import torch.cuda as cuda
import torch.nn as nn
import torch.nn.functional as F

In [30]:
class PartialConv2d(nn.Conv2d):
    """Conv2d whose output is re-normalised by how much of each sliding window
    falls on valid (non-zero) mask entries.

    Extra keyword arguments (everything else is forwarded to ``nn.Conv2d``):
        multi_channel (bool, default False): if True the mask has one channel per
            input channel; otherwise a single mask channel is shared by all
            input channels.
        return_mask (bool, default False): if True ``forward`` returns
            ``(output, update_mask)`` instead of just ``output``.

    The mask-derived tensors (``update_mask``, ``mask_ratio``) are cached and only
    recomputed when a mask is passed, the spatial size changes, or (multi-channel
    only) the cached batch size can no longer broadcast against the input.
    """

    def __init__(self, *args, **kwargs):
        # Remove our own options before nn.Conv2d sees (and rejects) them.
        multi_channel = kwargs.pop('multi_channel', False)
        return_mask = kwargs.pop('return_mask', False)

        super(PartialConv2d, self).__init__(*args, **kwargs)

        self.multi_channel = multi_channel
        self.return_mask = return_mask

        kernel_h, kernel_w = self.kernel_size
        if self.multi_channel:
            # One all-ones kernel per (out, in) pair: counts valid entries across every input channel.
            self.weight_maskUpdater = torch.ones(self.out_channels, self.in_channels, kernel_h, kernel_w)
        else:
            self.weight_maskUpdater = torch.ones(1, 1, kernel_h, kernel_w)

        # Number of mask entries under one fully-covered window.
        self.slide_winsize = self.weight_maskUpdater.shape[1] * kernel_h * kernel_w

        # Cache state: recomputed lazily in forward().
        self.last_size = (None, None)
        self.update_mask = None
        self.mask_ratio = None

    def _cached_batch_incompatible(self, input):
        """Return True when a cached multi-channel mask cannot broadcast over input's batch dim."""
        if not self.multi_channel or self.update_mask is None:
            return False
        return self.update_mask.shape[0] not in (1, input.shape[0])

    def forward(self, input, mask=None):
        """Apply the partial convolution.

        Args:
            input: 4-D tensor; dims 2 and 3 are used as height and width.
            mask: optional mask, convolved with ``weight_maskUpdater``. Its channel
                dim must be ``in_channels`` when ``multi_channel`` else 1.
                Presumably binary (1 = valid pixel) — TODO confirm with callers.
                If omitted, an all-ones mask is built (so only the zero padding
                counts as missing), or the cached result is reused when the
                input size is unchanged.

        Returns:
            ``output``, or ``(output, update_mask)`` when ``return_mask`` is set.
        """
        height, width = input.shape[2], input.shape[3]
        needs_update = (
            mask is not None
            or self.last_size != (height, width)
            or self._cached_batch_incompatible(input)
        )

        if needs_update:
            self.last_size = (height, width)

            with torch.no_grad():
                if self.weight_maskUpdater.type() != input.type():
                    self.weight_maskUpdater = self.weight_maskUpdater.to(input)

                if mask is None:
                    if self.multi_channel:
                        mask = input.new_ones(input.shape[0], input.shape[1], height, width)
                    else:
                        mask = input.new_ones(1, 1, height, width)
                else:
                    # F.conv2d needs mask and kernel on the same dtype/device.
                    mask = mask.to(input)

                # Sum of mask entries under each output window.
                self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride,
                                            padding=self.padding, dilation=self.dilation, groups=1)

                # Scale partially covered windows up to a full window; 1e-8 avoids division by zero.
                self.mask_ratio = self.slide_winsize / (self.update_mask + 1e-8)
                self.update_mask = torch.clamp(self.update_mask, 0, 1)
                # Zero the ratio wherever the window saw no valid entries.
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)

        if self.update_mask.type() != input.type() or self.mask_ratio.type() != input.type():
            # Tensor.to() is not in-place: the result must be assigned back.
            self.update_mask = self.update_mask.to(input)
            self.mask_ratio = self.mask_ratio.to(input)

        raw_out = super(PartialConv2d, self).forward(input)

        if self.bias is not None:
            # Rescale only the convolution term, add the bias back unscaled, then zero the holes.
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask)
        else:
            # mask_ratio is already zero wherever update_mask is zero.
            output = torch.mul(raw_out, self.mask_ratio)

        if self.return_mask:
            return output, self.update_mask
        return output

In [31]:
class Conv_ReLU_Block(nn.Module):
    """A single 3x3 convolution followed by an in-place ReLU.

    Args:
        use_partial_conv: if True the convolution is a ``PartialConv2d``;
            otherwise a plain ``nn.Conv2d``. Either way it has no bias and uses
            stride 1 / padding 1, so the spatial size is preserved.
        channels: number of input and output channels (default 64).
    """

    def __init__(self, use_partial_conv=False, channels=64):
        # Must name this class itself; the previous `CR_Block` name was undefined here.
        super(Conv_ReLU_Block, self).__init__()
        conv_cls = PartialConv2d if use_partial_conv else nn.Conv2d
        self.conv = conv_cls(in_channels=channels, out_channels=channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, input):
        """Return ``relu(conv(input))``; the ReLU overwrites the conv output in place."""
        return self.relu(self.conv(input))


# Short alias: other cells build layers from `CR_Block`.
CR_Block = Conv_ReLU_Block

In [32]:
class SRNet(nn.Module):
    """Single-channel residual network.

    Layout:
    - a 3x3 conv (1 -> 64 channels) + ReLU;
    - ``num_residual_layers`` ``Conv_ReLU_Block`` layers (64 -> 64);
    - a 3x3 conv (64 -> 1);
    - the network input is added back to the result (global residual).

    Every conv uses stride 1 / padding 1, so the output has the same shape as the
    input, which must have a single channel.

    Args:
        use_partial_conv: if True every convolution is a ``PartialConv2d``,
            otherwise a plain ``nn.Conv2d``.
        num_residual_layers: number of intermediate Conv+ReLU blocks (default 18).
    """

    def __init__(self, use_partial_conv, num_residual_layers=18):
        super(SRNet, self).__init__()
        self.use_partial_conv = use_partial_conv
        self.residual_layer = self.make_layer(Conv_ReLU_Block, num_residual_layers)

        # Previously both branches built PartialConv2d, so use_partial_conv=False was ignored.
        conv_cls = PartialConv2d if use_partial_conv else nn.Conv2d
        self.input = conv_cls(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.output = conv_cls(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

        # Previously defined but never invoked.
        self._initialize_weights()

    def _initialize_weights(self):
        """He-normal (fan_out, ReLU) init for convolutions; standard init for BatchNorm/Linear if present."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):  # also matches PartialConv2d, a Conv2d subclass
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def make_layer(self, block, num_of_layer):
        """Stack ``num_of_layer`` instances of ``block`` (built with this net's conv type) in a Sequential."""
        return nn.Sequential(*[block(use_partial_conv=self.use_partial_conv) for _ in range(num_of_layer)])

    def forward(self, input):
        """Predict a correction with the conv stack and add it to ``input``."""
        residual = input
        output = self.relu(self.input(input))
        output = self.residual_layer(output)
        output = self.output(output)
        return torch.add(output, residual)

In [28]:
# Seed every RNG in use so weight init and any sampling are reproducible.
# NOTE(review): previously read from `opt.seed`, but `opt` is never defined in this notebook.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

# Fall back to CPU instead of crashing on .cuda() when no GPU is present.
device = torch.device('cuda' if cuda.is_available() else 'cpu')

model = SRNet(use_partial_conv=True).to(device)
# Summed squared error; `reduction='sum'` replaces the deprecated `size_average=False`.
criterion = nn.MSELoss(reduction='sum').to(device)