In [None]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
#get_ipython().run_line_magic('pylab', 'inline')
import matplotlib.image as mpimg
import os
from glob import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import torch
import torchvision
import random
from PIL import Image
import PIL.ImageOps
from torch.utils.tensorboard import SummaryWriter

In [None]:
class Background_Removal_Dataset(torch.utils.data.Dataset):
    """Dataset yielding an input image stacked channel-wise with its trimap.

    Each sample is ``torch.cat((input, trimap), 0)`` after both images have
    been resized to 320x240 and converted to tensors, so the channel count of
    a sample is the sum of the channel counts of the two files.

    Args:
        input_img: Path of a single input image, or a sequence of such paths.
        trimap_img: Path of the matching trimap, or a sequence of paths in the
            same order as ``input_img``.

    Raises:
        ValueError: If the two arguments describe a different number of files.
    """

    def __init__(self, input_img, trimap_img):
        super().__init__()
        # Paths are stored as lists so that __len__ counts samples; calling
        # len() on a bare path string would count its characters instead.
        self.input_img = self._as_path_list(input_img)
        self.trimap_img = self._as_path_list(trimap_img)
        if len(self.input_img) != len(self.trimap_img):
            raise ValueError(
                'input_img and trimap_img must contain the same number of paths '
                '(got {} and {})'.format(len(self.input_img), len(self.trimap_img))
            )
        self.transform = transforms.Compose([transforms.Resize((320,240)),transforms.ToTensor()])

    @staticmethod
    def _as_path_list(paths):
        """Return ``paths`` as a list, wrapping a single path in a one-item list."""
        if isinstance(paths, (str, bytes, os.PathLike)):
            return [paths]
        return list(paths)

    def __len__(self):
        """Number of (input, trimap) pairs."""
        return len(self.input_img)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return the two tensors concatenated on dim 0."""
        # Image.open is lazy; the transform forces the pixel data to be read
        # while the file is still open, and the with-blocks then close it.
        with Image.open(self.input_img[idx]) as input_image:
            with Image.open(self.trimap_img[idx]) as trimap_image:
                return torch.cat((self.transform(input_image), self.transform(trimap_image)),0)

In [None]:
class DoubleConv(nn.Module):
    """Two successive (3x3 conv -> batch norm -> ReLU) stages.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the second convolution.
        mid_channels: Channels between the two convolutions; any falsy value
            (the default ``None`` included) means "same as ``out_channels``".
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels

        # Layers are created in forward order; padding=1 with a 3x3 kernel
        # keeps the spatial size unchanged.
        stages = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.ReLU(inplace=False))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both convolution stages."""
        out = self.double_conv(x)
        return out


class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a ``DoubleConv``.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        # Kept as a two-element Sequential so parameter names stay
        # ``maxpool_conv.1.*``.
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Pool ``x`` and then convolve it."""
        pooled_and_convolved = self.maxpool_conv(x)
        return pooled_and_convolved


class Up(nn.Module):
    """Decoder stage: upsample, pad to the skip tensor's size, concatenate, convolve.

    Args:
        in_channels: Channels fed to the ``DoubleConv`` after concatenation.
        out_channels: Channels produced by the ``DoubleConv``.
        bilinear: If true, upsample with ``nn.Upsample`` (bilinear); otherwise
            use a learned ``nn.ConvTranspose2d`` that halves the channel count.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels , half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # Bilinear upsampling leaves the channel count alone, so the
            # DoubleConv uses ``half`` as its intermediate width.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        """Upsample ``x1``, match it to ``x2`` on dims 2 and 3, and fuse the two."""
        upsampled = self.up(x1)

        # Size gap between the skip tensor and the upsampled tensor.
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        before_w = gap_w // 2
        before_h = gap_h // 2

        # F.pad takes the last dimension first: two values for dim 3, then
        # two values for dim 2.
        padding = [before_w, gap_w - before_w, before_h, gap_h - before_h]
        upsampled = F.pad(upsampled, padding)

        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    """Final 1x1 convolution mapping features to the requested output channels.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the returned tensor.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
        )

    def forward(self, x):
        """Apply the 1x1 convolution to ``x``."""
        projected = self.conv(x)
        return projected
    

class UNet(nn.Module):
    """U-shaped encoder/decoder built from ``DoubleConv``, ``Down``, ``Up`` and ``OutConv``.

    Args:
        n_channels: Channels of the input tensor.
        n_classes: Channels of the returned tensor.
        bilinear: Forwarded to every ``Up`` stage; it also halves the width
            of the deepest encoder stage and of the first three decoder
            outputs.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Encoder.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        divisor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // divisor)

        # Decoder; each stage receives the previous output plus one skip tensor.
        self.up1 = Up(1024, 512 // divisor, bilinear)
        self.up2 = Up(512, 256 // divisor, bilinear)
        self.up3 = Up(256, 128 // divisor, bilinear)
        self.up4 = Up(128, 64, bilinear)

        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return the output of the final 1x1 convolution for input ``x``."""
        # Collect every encoder output; all but the deepest act as skips.
        features = [self.inc(x)]
        for encoder in (self.down1, self.down2, self.down3, self.down4):
            features.append(encoder(features[-1]))

        # Walk back up, pairing each decoder with the most recent unused skip.
        out = features.pop()
        for decoder in (self.up1, self.up2, self.up3, self.up4):
            out = decoder(out, features.pop())

        return self.outc(out)


In [None]:
# Build a 4-in / 4-out UNet on the GPU when one is available, else on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(n_channels=4, n_classes=4).to(device=device)

# The checkpoint is deserialised onto the CPU; load_state_dict then copies the
# values into the model's existing parameters.
net.load_state_dict(
    torch.load('32generator_final.pth', map_location=torch.device('cpu'))
)

In [None]:
# Run the network on one (input image, trimap) pair.
dataset = Background_Removal_Dataset('./input_images/298.jpg','./trimap_images/trimap_298.jpg')
train_loader = torch.utils.data.DataLoader(dataset, batch_size=1)

# NOTE(review): the network is still in training mode here, so BatchNorm uses
# per-batch statistics. Add net.eval() if the checkpoint is meant to be run in
# evaluation mode - confirm against the training setup.
with torch.no_grad():  # inference only: no autograd graph is needed
    for x in train_loader:
        # The DataLoader yields CPU tensors; move each batch to the model's
        # device, otherwise the forward pass fails when `net` is on CUDA.
        test_sr = net(x.to(device=device))

In [None]:
# Write the most recent network output to disk as a PNG.
import torchvision.utils as vutils

vutils.save_image(
    test_sr.data,
    'generated_image.png',
    normalize=True,  # forwarded to make_grid, which min-max rescales the values to [0, 1]
)