In [2]:
import torch
import torch.nn as nn

In [3]:
class DiscBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel=4, stride=2):
        super(DiscBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=kernel,
                      stride=stride,
                      bias=False,
                      padding_mode='reflect'),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(negative_slope=0.2)
        )
    def forward(self, x):
        return self.conv(x)

In [4]:
class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features=None, kernel=4, stride=2, padding=1):
        super().__init__()
        if not features:
            features = [64, 128, 256, 512]
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels*2, features[0], kernel_size=kernel, stride=stride, padding=padding, padding_mode='reflect'),
            nn.LeakyReLU(0.2)
        )

        layers = []
        in_channels = features[0]
        for feature in features[1:]:
            layers.append(DiscBlock(in_channels=in_channels,
                                   out_channels=feature,
                                   stride=1 if feature==features[-1] else 2,
                                   kernel=kernel
                                   ))
            in_channels = feature

        layers.append(nn.Conv2d(
            in_channels=in_channels,
            out_channels=1,
            kernel_size=kernel,
            stride = (1, 1),
            padding=1,
            padding_mode='reflect'
        ))
        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        x = torch.cat((x,y), dim=1)
        x = self.initial(x)
        x = self.model(x)
        return x

In [5]:
def test_discriminator():
    a = Discriminator()
    img = torch.ones(10, 3, 256, 256)
    print(a.forward(img, img).shape)

In [6]:
test_discriminator()

torch.Size([10, 1, 26, 26])


In [7]:
class GeneratorBlock(nn.Module):
    def __init__(self, in_channels, out_channels, down=True, act='relu', use_dropout=False):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=(4,4),
                      stride=(2,2),
                      padding=(1,1),
                      padding_mode='reflect')
            if down
            else nn.ConvTranspose2d(in_channels=in_channels,
                                    out_channels=out_channels,
                                    kernel_size=(4,4),
                                    stride=(2,2),
                                    padding=(1,1),
                                    bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU() if act=='relu' else nn.LeakyReLU(0.2)
        )
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.conv(x)
        x = self.dropout(x) if self.use_dropout else x
        return x

In [8]:
class Generator(nn.Module):
    def __init__(self, in_channels=3, features=64):
        super().__init__()
        self.initial_down = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=features,
                kernel_size=(4,4),
                stride=(2,2),
                padding=1,
                padding_mode='reflect'
            ),
            nn.LeakyReLU(0.2)
        )
        self.down1 = GeneratorBlock(in_channels=features, out_channels=features*2, down=True, act='leaky', use_dropout=False)
        self.down2 = GeneratorBlock(in_channels=features*2, out_channels=features*4, down=True, act='leaky', use_dropout=False)
        self.down3 = GeneratorBlock(in_channels=features*4, out_channels=features*8, down=True, act='leaky', use_dropout=False)
        self.down4 = GeneratorBlock(in_channels=features*8, out_channels=features*8, down=True, act='leaky', use_dropout=False)
        self.down5 = GeneratorBlock(in_channels=features*8, out_channels=features*8, down=True, act='leaky', use_dropout=False)
        self.down6 = GeneratorBlock(in_channels=features*8, out_channels=features*8, down=True, act='leaky', use_dropout=False)

        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels=features*8,
                      out_channels=features*8,
                      kernel_size=(4,4), stride=(2,2),
                      padding=(1,1),
                      padding_mode="reflect"),
            nn.ReLU()
        )
        self.up1 = GeneratorBlock(in_channels=features*8, out_channels=features*8, down=False, act='relu', use_dropout=True)
        self.up2 = GeneratorBlock(in_channels=features*8*2, out_channels=features*8, down=False, act='relu', use_dropout=True)
        self.up3 = GeneratorBlock(in_channels=features*8*2, out_channels=features*8, down=False, act='relu', use_dropout=True)
        self.up4 = GeneratorBlock(in_channels=features*8*2, out_channels=features*8, down=False, act='relu', use_dropout=False)
        self.up5 = GeneratorBlock(in_channels=features*8*2, out_channels=features*4, down=False, act='relu', use_dropout=False)
        self.up6 = GeneratorBlock(in_channels=features*4*2, out_channels=features*2, down=False, act='relu', use_dropout=False)
        self.up7 = GeneratorBlock(in_channels=features*2*2, out_channels=features, down=False, act='relu', use_dropout=False)

        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(in_channels=features*2, out_channels=in_channels,kernel_size=(4,4), stride=(2,2), padding=(1,1)),
            nn.Tanh()
        )

    def forward(self, x):
        d1 = self.initial_down(x)
        d2 = self.down1(d1)
        d3 = self.down2(d2)
        d4 = self.down3(d3)
        d5 = self.down4(d4)
        d6 = self.down5(d5)
        d7 = self.down6(d6)
        bottleneck = self.bottleneck(d7)
        up1 = self.up1(bottleneck)
        up2 = self.up2(torch.cat([d7, up1], dim=1))
        up3 = self.up3(torch.cat([d6, up2], dim=1))
        up4 = self.up4(torch.cat([d5, up3], dim=1))
        up5 = self.up5(torch.cat([d4, up4], dim=1))
        up6 = self.up6(torch.cat([d3, up5], dim=1))
        up7 = self.up7(torch.cat([d2, up6], dim=1))
        return self.final_up(torch.cat([d1, up7], dim=1))

In [9]:
def test_generator():
    test_imgs = torch.ones(10, 3, 256, 256)
    gen = Generator(3, 64)
    print(gen(test_imgs).shape)

In [10]:
test_generator()



torch.Size([10, 3, 256, 256])


In [32]:
from PIL import Image
import numpy as np
import os
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2

  warn(f"Failed to load image Python extension: {e}")


In [45]:
class MapDataset(Dataset):
    def __init__(self, dir):
        super().__init__()
        self.dir = dir
        self.all_files = os.listdir(dir)
        self.both_transform = A.Compose(
            [
                A.Resize(width=256, height=256),
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5)
            ],
            additional_targets={
                'image0': 'image'
            }
        )
        self.transform_in = A.Compose(
            transforms=[A.ColorJitter(p=0.2),
                        A.Normalize(mean=[0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5]),
                        ToTensorV2()]
        )
        self.transform_out = A.Compose(
            transforms=[
                 A.Normalize(mean=[0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5]),
                 ToTensorV2()
            ]
        )
    def __len__(self):
        return len(self.all_files)
    def __getitem__(self, idx):
        img_path = os.path.join(self.dir, self.all_files[idx])
        image = np.array(Image.open(img_path))
        in_image = image[:, :600, :]
        out_image = image[:, 600:, :]
        augmentations = self.both_transform(image=in_image, image0=out_image)
        in_image, out_image = augmentations['image'], augmentations['image0']
        in_image = self.transform_in(image=in_image)['image']
        out_image = self.transform_out(image=out_image)['image']
        return in_image, out_image

In [46]:
dataset = MapDataset('maps/train')

In [50]:
dataset[0][1].shape

torch.Size([3, 256, 256])

In [31]:
np.array(Image.open('./maps/train/1.jpg')).shape

(600, 1200, 3)