In [100]:
import torch.nn as nn
import torch 
from torchvision import  utils
import os 
import torch.nn.functional as F
import PIL
import numpy as np
from PIL import Image 


In [101]:
# Load the saved image tensor and its segmentation mask onto the CPU.
# map_location='cpu' remaps storages while deserializing, so files written from
# a GPU session also load on a machine without CUDA (a trailing .cpu() would
# only run after the load had already failed there).
# NOTE: torch.load unpickles the file -- only load files from a trusted source.
x = torch.load('../fake_img', map_location='cpu')
mask = torch.load('../global_seg', map_location='cpu')

In [129]:
class Mask_Enlarger(nn.Module):
    """Grow a single-channel mask outward with a frozen 3x3 averaging filter.

    The filter is a bias-free ``Conv2d(1, 1, 3, padding=1)`` whose nine weights
    are all ``1/9``; its parameters are excluded from autograd.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv.weight.data.fill_(1/9)

        # The filter is fixed, so no parameter should receive gradients.
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, input, hard=1, soft=1):
        """Enlarge ``input`` with ``hard`` binarising passes, then ``soft`` smooth passes.

        Args:
            input: mask tensor with a single channel (the filter is 1 -> 1).
            hard: number of passes in which the 3x3 filter is applied and every
                non-zero response is then set to 1.
            soft: number of passes in which only the 3x3 filter is applied, so
                values near the boundary can end up non-binary.

        Returns:
            ``input`` plus the filtered mask, clamped to the range [0, 1].

        Raises:
            AssertionError: if ``hard + soft`` is not positive.
        """
        assert hard + soft > 0

        grown = input

        # Hard stage: filter, then snap anything the filter touched to 1.
        if hard > 0:
            for _ in range(hard):
                grown = self.conv(grown)
                grown[grown != 0] = 1

        # Soft stage: filter only, keeping the fractional boundary values.
        if soft > 0:
            for _ in range(soft):
                grown = self.conv(grown)

        return torch.clamp(input + grown, 0, 1)




In [102]:
class PartialConv(nn.Module):
    """Mask-aware box filter built from two frozen, bias-free convolutions.

    ``mask_conv`` (1 -> 1 channel, all-ones kernel) sums the mask under each
    window; ``input_conv`` (3 -> 3 channels) averages each channel separately
    over the same window. The bigger the kernel size, the more pixels are
    dilated per application.
    """

    def __init__(self, kernel_size=3):
        super().__init__()

        assert kernel_size in [3,5,7]
        half_width = kernel_size // 2
        window_area = kernel_size * kernel_size

        self.kernel_size = kernel_size

        # All-ones kernel: the response is the sum of mask values in the window.
        self.mask_conv = nn.Conv2d(1, 1, kernel_size, 1, half_width, bias=False)
        self.mask_conv.weight.data.fill_(1.0)

        # Output channel c only looks at input channel c, with uniform weights.
        self.input_conv = nn.Conv2d(3, 3, kernel_size, 1, half_width, bias=False)
        for channel in range(3):
            kernel = torch.zeros(3, kernel_size, kernel_size)
            kernel[channel, :, :] = 1 / window_area
            self.input_conv.weight.data[channel] = kernel

        # Both filters are fixed; keep them out of autograd.
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, input, mask, return_new_mask = False):
        """Average the masked input over each window, renormalised by mask coverage.

        Args:
            input: 3-channel image tensor (``input_conv`` is 3 -> 3).
            mask: 1-channel mask tensor (``mask_conv`` is 1 -> 1); it is
                multiplied into ``input`` before filtering.
            return_new_mask: when true, also return the updated mask.

        Returns:
            The filtered image, set to 0 wherever the mask sum under the window
            is exactly 0. With ``return_new_mask`` the result is a tuple
            ``(image, new_mask)`` where ``new_mask`` is a float tensor holding
            1 where the window's mask sum is non-zero and 0 elsewhere.
        """
        window_area = self.kernel_size * self.kernel_size

        averaged = self.input_conv(input * mask)
        coverage = self.mask_conv(mask)

        # Windows whose mask sum is exactly zero have nothing to average.
        empty = coverage == 0

        # Replace zero coverage by 1 so the division below is always defined;
        # those positions are zeroed again right after.
        rescale = window_area / coverage.masked_fill_(empty, 1.0)

        filled = averaged * rescale
        filled = filled.masked_fill_(empty, 0.0)

        if return_new_mask:
            return filled, (~empty).float()
        return filled


In [123]:
class RGB_Enlarger(nn.Module):
    """Enlarge an image by repeatedly applying a ``PartialConv``.

    Each step feeds the image and mask returned by the previous step back into
    ``self.enlarger`` and asks it for the updated mask as well.
    """

    def __init__(self):
        super().__init__()
        self.enlarger = PartialConv()
        # Not used by forward(); kept so existing attribute access and
        # state_dict keys stay unchanged.
        self.mask_enlarger = Mask_Enlarger()

    def forward(self, x, mask, num_steps=3):
        """Run ``num_steps`` rounds of partial convolution and return the image.

        Defined as ``forward`` (rather than overriding ``__call__``) so that
        calling the module goes through ``nn.Module.__call__`` and registered
        hooks run; ``module(x, mask)`` keeps working exactly as before.

        Args:
            x: image tensor passed to ``self.enlarger`` on the first step.
            mask: mask tensor passed to ``self.enlarger`` on the first step.
            num_steps: how many times ``self.enlarger`` is applied (default 3,
                the previously hard-coded value). With 0, ``x`` is returned
                unchanged.

        Returns:
            The image produced by the last step; the final mask is discarded.
        """
        enlarged_x, enlarged_mask = x, mask
        for _ in range(num_steps):
            enlarged_x, enlarged_mask = self.enlarger(enlarged_x, enlarged_mask, True)

        return enlarged_x

In [124]:
# Build the enlarger with its default constructor arguments.
rgb_enlarger = RGB_Enlarger()

In [125]:
# Run the enlarger on the loaded tensors `x` and `mask`; `out` is the enlarged image.
out = rgb_enlarger(  x,mask )

In [126]:
# Write the result to x.png, rescaling values from (-1, 1) for display.
# Uses the `utils` module imported from torchvision in the first cell, so this
# cell no longer depends on a later cell having run `import torchvision`.
# NOTE(review): newer torchvision releases name the `range` keyword
# `value_range` -- confirm against the installed version before upgrading.
utils.save_image(out, 'x.png', normalize=True, range=(-1,1)  )

In [18]:
import torchvision 