In [3]:
import os
import cv2
from PIL import Image
import albumentations as A
import torch
from torch import nn
from albumentations.pytorch import ToTensorV2
import numpy as np
from torchvision.utils import save_image
from tqdm import tqdm

In [4]:
class GeneratorBlockDown(nn.Module):
    """Encoder stage: strided convolution, batch norm, then LeakyReLU.

    The convolution uses a 4x4 kernel with stride 2 and reflect padding of 1,
    and has no bias term (it is followed directly by batch norm).

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        downsample = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
            padding_mode='reflect',
            bias=False,
        )
        normalise = nn.BatchNorm2d(out_channels)
        activate = nn.LeakyReLU(0.2)

        # Attribute name and layer order are kept so checkpoint keys stay the same.
        self.block = nn.Sequential(downsample, normalise, activate)

    def forward(self, x):
        """Run ``x`` through the conv / norm / activation stack."""
        return self.block(x)

class GeneratorBlockUp(nn.Module):
    """Decoder stage: transposed convolution, batch norm, ReLU, optional dropout.

    The transposed convolution uses a 4x4 kernel with stride 2 and padding 1,
    and has no bias term (it is followed directly by batch norm).

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the transposed convolution.
        dropout: when true the stage ends with ``nn.Dropout(0.5)``; otherwise
            the last slot is an ``nn.Identity`` placeholder.
    """

    def __init__(self, in_channels: int, out_channels: int, dropout: bool):
        super().__init__()

        upsample = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
            bias=False,
        )
        normalise = nn.BatchNorm2d(out_channels)
        activate = nn.ReLU()

        # The fourth slot is always present so the Sequential has the same length either way.
        if dropout:
            tail = nn.Dropout(0.5)
        else:
            tail = nn.Identity()

        self.block = nn.Sequential(upsample, normalise, activate, tail)

    def forward(self, x):
        """Run ``x`` through the upsampling stack."""
        return self.block(x)

class Generator(nn.Module):
    """Encoder-decoder image generator with skip connections.

    The encoder is eight stride-2 stages (``initial_down``, ``down1``..``down6``
    and ``bottleneck``). The decoder is eight stride-2 transposed-convolution
    stages (``up1``..``up7`` and ``final_up``). From ``up2`` onwards each decoder
    stage receives the previous decoder output concatenated on the channel axis
    with the encoder output of matching depth. The result passes through ``Tanh``.

    Args:
        in_channels: channel count of the input image; the output has the same
            number of channels.
        features: channel count after the first encoder stage; deeper stages use
            multiples of it, up to ``features * 8``.
    """

    def __init__(self, in_channels: int=3, features: int=64):
        super().__init__()
        wide = features * 8  # channel width shared by the deepest stages

        # First encoder stage: convolution keeps its bias and has no batch norm.
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, features, kernel_size=4, stride=2, padding=1, padding_mode='reflect'),
            nn.LeakyReLU(0.2),
        )
        self.down1 = GeneratorBlockDown(features, features * 2)
        self.down2 = GeneratorBlockDown(features * 2, features * 4)
        self.down3 = GeneratorBlockDown(features * 4, wide)
        self.down4 = GeneratorBlockDown(wide, wide)
        self.down5 = GeneratorBlockDown(wide, wide)
        self.down6 = GeneratorBlockDown(wide, wide)

        # Innermost stage: one more strided convolution, ReLU, no normalisation.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(wide, wide, kernel_size=4, stride=2, padding=1, padding_mode='reflect'),
            nn.ReLU(),
        )

        # Decoder. Input widths from up2 onwards are doubled by the concatenated skip tensor.
        self.up1 = GeneratorBlockUp(wide, wide, dropout=True)
        self.up2 = GeneratorBlockUp(wide * 2, wide, dropout=True)
        self.up3 = GeneratorBlockUp(wide * 2, wide, dropout=True)
        self.up4 = GeneratorBlockUp(wide * 2, wide, dropout=False)
        self.up5 = GeneratorBlockUp(wide * 2, features * 4, dropout=False)
        self.up6 = GeneratorBlockUp(wide, features * 2, dropout=False)
        self.up7 = GeneratorBlockUp(features * 4, features, dropout=False)

        # Output stage: back to the input channel count, squashed by Tanh.
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(features * 2, in_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x``, then decode it while re-injecting the encoder activations."""
        # Encoder outputs, shallowest first; each one is reused as a skip input.
        skips = [self.initial_down(x)]
        encoders = (self.down1, self.down2, self.down3, self.down4, self.down5, self.down6)
        for encoder in encoders:
            skips.append(encoder(skips[-1]))

        out = self.up1(self.bottleneck(skips[-1]))

        # Walk back up: the deepest skip feeds up2, the shallowest feeds final_up.
        decoders = (self.up2, self.up3, self.up4, self.up5, self.up6, self.up7, self.final_up)
        for decoder, skip in zip(decoders, reversed(skips)):
            out = decoder(torch.cat([skip, out], dim=1))
        return out

In [5]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = Generator().to(device)

# Only the 'state_dict' entry of the checkpoint file is used.
state = torch.load('pix2pix_gen_vgg_15.pth', map_location=device)['state_dict']
model.load_state_dict(state)

# Switch to evaluation mode (affects the Dropout and BatchNorm layers).
model.eval()
print('model loaded')

model loaded


In [6]:
# Source folders for the dog and cat images, given relative to the
# directory the notebook is run from.
path_dogs = '../../dataset/test_set/test_set/dogs'
path_cats = '../../dataset/test_set/test_set/cats'

In [7]:
# Path of the first entry os.listdir reports for the dogs folder
# (os.listdir returns entries in arbitrary order, so "first" is not stable).
# NOTE(review): the conversion cell further down builds its own per-file
# paths, so this value looks unused there - confirm before removing.
img_path = os.path.join(path_dogs, os.listdir(path_dogs)[0])

In [8]:
# Model input pipeline: normalise each channel with mean 0.5 / std 0.5, then
# convert the HWC numpy array into a CHW torch tensor.
#
# A.Compose is the top-level pipeline container and always runs its transforms.
# A.Sequential is documented as a building block to nest inside a Compose and
# carries its own application probability (default p=0.5), so using it as the
# outermost pipeline can leave the image unconverted on some calls.
formatting = A.Compose([
    A.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ToTensorV2()
])

In [9]:
def convert_folder(src_dir, prefix, skip=('_DS_Store', '.DS_Store')):
    """Degrade, restore and save every image found in ``src_dir``.

    For each file three images are written under the source file name:

    * ``<prefix>_orig/``       - the input resized to 256x256,
    * ``<prefix>_bad/``        - that image quantized to 8 colours (as RGB),
    * ``<prefix>_refactored/`` - the generator output for the quantized image.

    The output folders are created when missing. Uses the notebook globals
    ``model``, ``formatting`` and ``device``.

    Args:
        src_dir: folder containing the input images.
        prefix: prefix of the three output folder names (e.g. ``'dogs'``).
        skip: file names to ignore (Finder metadata files, not images).
    """
    out_dirs = {kind: f'{prefix}_{kind}' for kind in ('refactored', 'orig', 'bad')}
    for folder in out_dirs.values():
        # Without this a fresh checkout fails on the very first save.
        os.makedirs(folder, exist_ok=True)

    # Sorted so the processing order is reproducible between runs.
    names = sorted(name for name in os.listdir(src_dir) if name not in skip)
    for name in tqdm(names):
        # resize() returns an independent, fully loaded copy, so the source
        # file handle can be closed immediately afterwards.
        with Image.open(os.path.join(src_dir, name)) as image:
            orig = image.resize((256, 256))
        bad_image = orig.quantize(8).convert('RGB')

        format_image = formatting(image=np.asarray(bad_image))['image'].to(device)
        with torch.no_grad():
            res_image = model(format_image.unsqueeze(0)).cpu()

        # Map the Tanh output from [-1, 1] back to [0, 1]; clamp rather than
        # abs() so a stray out-of-range value saturates instead of reflecting.
        restored = (res_image * 0.5 + 0.5).clamp(0, 1)
        save_image(restored, os.path.join(out_dirs['refactored'], name))
        orig.save(os.path.join(out_dirs['orig'], name))
        bad_image.save(os.path.join(out_dirs['bad'], name))


convert_folder(path_dogs, 'dogs')
convert_folder(path_cats, 'cats')

100%|██████████| 1012/1012 [01:23<00:00, 12.07it/s]
100%|██████████| 1011/1011 [01:11<00:00, 14.09it/s]
