In [1]:
import os
import cv2
from PIL import Image
import albumentations as A
import torch
from torch import nn
from albumentations.pytorch import ToTensorV2
import numpy as np
from torchvision.utils import save_image
from tqdm import tqdm

In [2]:
for path in ['cats_orig', 'cats_small', 'cats_bicubic', 'cats_nearest', 'cats_refactored']:
    if not os.path.isdir(path):
        os.mkdir(path)

In [3]:
for path in ['dogs_orig', 'dogs_small', 'dogs_bicubic', 'dogs_nearest', 'dogs_refactored']:
    if not os.path.isdir(path):
        os.mkdir(path)

In [4]:
class ResidualBlock(nn.Module):
    def __init__(
            self,
            in_channels = 64,
            out_channels = 64,
            stride = 1,
            kernel_size = 3,
            padding = 1,
            padding_mode='reflect'
    ):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels = in_channels,
                               out_channels = out_channels,
                               kernel_size = kernel_size,
                               stride = stride,
                               padding= padding,
                               bias=False)

        self.bn1 = nn.BatchNorm2d(out_channels)
        self.activation = nn.PReLU(num_parameters=out_channels)

        self.conv2 = nn.Conv2d(in_channels = in_channels,
                               out_channels = out_channels,
                               kernel_size = kernel_size,
                               stride = stride,
                               padding= padding,
                               bias=False)

        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.activation(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = out + x
        return out

In [5]:
class GeneratorUpsample(nn.Module):
    def __init__(
            self,
            in_channels = 64,
            upsample_value = 2,
            stride = 1,
            kernel_size = 3,
            padding = 1
    ):
        super().__init__()
        self.upsample_value = upsample_value
        self.in_channels = in_channels
        self.stride = stride
        self.kernel_size = kernel_size
        self.padding = padding

        self.conv = nn.Conv2d(in_channels=self.in_channels,
                              out_channels=self.in_channels*upsample_value*upsample_value,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              stride = self.stride)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor=upsample_value)
        self.activation = nn.PReLU(in_channels)

    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        x = self.activation(x)
        return x

In [6]:
class Generator(nn.Module):
    def __init__(
            self,
            in_channels = 3,
            out_channels = 3,
            hidden_channels = 64,
            num_res_blocks = 16
    ):
        super().__init__()

        self.in_channels = in_channels
        self.num_res_blocks = num_res_blocks
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels

        self.initial_conv = nn.Conv2d(
            in_channels = self.in_channels,
            out_channels = self.hidden_channels,
            stride = 1,
            kernel_size = 9,
            padding = 4
        )
        self.activation = nn.PReLU(self.hidden_channels)
        self.residual_blocks = nn.Sequential(*[ResidualBlock() for _ in range(num_res_blocks)])
        self.middle_conv = nn.Conv2d(
            in_channels = self.hidden_channels,
            out_channels = self.hidden_channels,
            stride = 1,
            kernel_size = 3,
            padding = 1,
            bias=False
        )
        self.bn = nn.BatchNorm2d(self.hidden_channels)
        self.final_conv = nn.Conv2d(
            in_channels = self.hidden_channels,
            out_channels = self.out_channels,
            stride = 1,
            kernel_size = 9,
            padding_mode='reflect',
            padding = 4
        )
        self.up1 = GeneratorUpsample()
        self.up2 = GeneratorUpsample()


    def forward(self, x):
        x = self.initial_conv(x)
        x = self.activation(x)
        out = self.residual_blocks(x)
        out = self.middle_conv(out)
        out = self.bn(out)
        out = out + x
        out = self.up1(out)
        out = self.up2(out)
        out = self.final_conv(out)
        return out

In [7]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
state = torch.load('srgan_gen_430.pth', map_location=device)['state_dict']
model = Generator().to(device)
model.load_state_dict(state)
model.eval()
print('model loaded')

model loaded


In [8]:
path_dogs = '../../dataset/test_set/test_set/dogs'
path_cats = '../../dataset/test_set/test_set/cats'

In [9]:
img_path = os.path.join(path_dogs, os.listdir(path_dogs)[0])

In [10]:
prep = A.Compose(
            [
                A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
                ToTensorV2()
            ]
        )

In [14]:
for i in tqdm(os.listdir(path_dogs)[:10]):
    if i != '_DS_Store':
        img_path = os.path.join(path_dogs, i)
        image = Image.open(img_path)


        image_shape = np.asarray(image).shape

        image_small = A.resize(np.asarray(image), image_shape[0]//4, image_shape[1]//4, interpolation=Image.Resampling.BICUBIC)
        image_bicubic = A.resize(np.asarray(image_small), image_shape[0], image_shape[1], interpolation=Image.Resampling.BICUBIC)
        image_nearest = A.resize(np.asarray(image_small), image_shape[0], image_shape[1], interpolation=Image.Resampling.BILINEAR)

        format_image = prep(image = image_small)['image'].to(device)
        with torch.no_grad():
            res_image = model(format_image.unsqueeze(0)).detach().cpu()
        save_image((res_image).abs(), os.path.join('dogs_refactored/', i))

        Image.fromarray(image_small).save(os.path.join('dogs_small/', i))
        Image.fromarray(image_bicubic).save(os.path.join('dogs_bicubic/', i))
        Image.fromarray(image_nearest).save(os.path.join('dogs_nearest/', i))
        image.save(os.path.join('dogs_orig/', i))


100%|██████████| 10/10 [00:00<00:00, 13.91it/s]


In [20]:
np.asarray(image).shape[0]//4

124

In [24]:
A.resize(image_shape[0]//4, image_shape[1]//4, interpolation=Image.BICUBIC)

  A.resize((image_shape[0]//4, image_shape[1]//4), interpolation=Image.BICUBIC)


AttributeError: 'tuple' object has no attribute 'shape'

In [12]:
Image.NEAREST

  Image.NEAREST


0