In [6]:
# from unet import Unet
import torchvision as tv
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from enum import Enum


# N_CLASSES = 11


class SkipType(Enum):
    SKIP = 0
    NO_SKIP = 1
    ZERO_SKIP = 2


class UpBlock(nn.Module):
    def __init__(self, ch_out, skip=SkipType.SKIP, dropout=None):
        super(UpBlock, self).__init__()
        self.skip = skip
        self.upconv = nn.ConvTranspose2d(ch_out * 2, ch_out * 2, kernel_size=3, padding=1, stride=2, output_padding=1)
        self.conv1 = nn.Conv2d(ch_out * 2 if skip == SkipType.NO_SKIP else ch_out * 3, ch_out, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()

        if dropout is not None and dropout != 0:
            self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, u=None, padding=None):
        if u is None and self.skip:
            raise ValueError("Expected skip connection")

        if padding is None:
            padding = [x.shape[2] % 2, x.shape[3] % 2]

        u = self.upconv(u)
        u = F.pad(u, [padding[1], 0, padding[0], 0])

        if self.skip == SkipType.SKIP:
            x1 = torch.cat((x, u), dim=1)
        elif self.skip == SkipType.ZERO_SKIP:
            x1 = torch.cat((torch.zeros_like(x), u), dim=1)
        else:
            x1 = u

        x2 = self.relu1(self.conv1(x1))
        x3 = self.relu2(self.conv2(x2))
        if getattr(self, 'dropout', None) is not None:
            x3 = self.dropout(x3)
        return x3


class Unet(nn.Module):
    def __init__(self, layers, output_channels=3, skip=SkipType.SKIP, dropout=None):
        super(Unet, self).__init__()
        prev_layer = 3
        for i, layer in enumerate(layers, start=1):
            setattr(self, f'down{i}', self.enc_block(prev_layer, layer, max_pool=prev_layer != 3))
            prev_layer = layer
        for i, layer in reversed(list(enumerate(layers[:-1], start=1))):
            setattr(self, f'up{i}', UpBlock(layer, skip, dropout))
        self.final = nn.Conv2d(layers[0], output_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward_vars(self, x):
        children = list(self.named_children())
        vars = {}
        downs = list(map(lambda nm: nm[1], filter(lambda nm: 'down' in nm[0], children)))
        ups = list(map(lambda nm: nm[1], filter(lambda nm: 'up' in nm[0], children)))
        final = children[-2][1]
        sigmoid = children[-1][1]

        vars['x'] = x
        for i, down_i in enumerate(downs[:-1], start=1):
            vars[f'x{i}'] = down_i(vars[f'x{"" if i == 1 else i - 1}'])
            print("down", vars[f'x{i}'].shape)

        vars[f'u{len(downs)}'] = downs[-1](vars[f'x{len(downs) - 1}'])
        for i, up_i in enumerate(ups):
            vars[f'u{len(ups) - i}'] = up_i(vars[f'x{len(ups) - i}'], vars[f'u{len(ups) - i + 1}'])
            print("up", vars[f'u{len(ups) - i}'].shape)

        vars['res'] = final(vars['u1'])
        vars['res_sigmoid'] = sigmoid(vars['res'])
        return vars

    def forward(self, x):
        return self.forward_vars(x)['res_sigmoid']

    def enc_block(self, ch_in, ch_out, max_pool=True):
        layers = [nn.MaxPool2d(kernel_size=2)] if max_pool else []
        layers += [
            nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        return nn.Sequential(*layers)

    def get_last_block_inputs(self, x):
        with torch.no_grad():
            vars = self.forward_vars(x)

            return vars['x1'], vars['u2']


In [8]:
root = "small_data"  # os.path.join("data")
device = "cuda"  # torch.device("cuda" if torch.cuda.is_available() else "cpu")
transform = tv.transforms.Compose([
    # tv.transforms.RandomAffine(0, translate=(5/96, 5/96), fill=(255,255,255)),
    # tv.transforms.ColorJitter(hue=0.5),
    tv.transforms.RandomHorizontalFlip(p=0.5),
    tv.transforms.ToTensor(),
    tv.transforms.Normalize((0.5, 0.5, 0.5,), (0.5, 0.5, 0.5,))
])
dataset = ImageFolder(
    root=root,
    transform=transform
)
dataloader = DataLoader(dataset,
                        batch_size=16,
                        shuffle=True,
                        num_workers=1,
                        drop_last=True
                        )

In [9]:
img = next(iter(dataloader))[0].to(device)
img.shape

torch.Size([16, 3, 96, 96])

In [11]:
unet = Unet([8, 16, 32, 64, 128]).to(device)
unet(img).shape

down torch.Size([16, 8, 96, 96])
down torch.Size([16, 16, 48, 48])
down torch.Size([16, 32, 24, 24])
down torch.Size([16, 64, 12, 12])
up torch.Size([16, 64, 12, 12])
up torch.Size([16, 32, 24, 24])
up torch.Size([16, 16, 48, 48])
up torch.Size([16, 8, 96, 96])


torch.Size([16, 3, 96, 96])

In [None]:
class MyGeneratorUpsampling(nn.Module):
    # input is (batch, latent_dim, 1, 1)
    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        channels = [latent_dim, 64 * 8, 64 * 8, 64 * 4, 64 * 4, 64 * 2, 64, 64, 3]
        params = [3, 2, 2, 2, 1, 2, 2, 1]
        self.layers = nn.Sequential()
        for i, (cur_channels, next_channels, scale) in enumerate(zip(channels[:-1], channels[1:], params)):
            if scale > 1:
                self.layers.append(nn.Upsample(scale_factor=scale))
            self.layers.append(nn.Conv2d(cur_channels, next_channels, kernel_size=3, stride=1, padding=1))
            self.layers.append(*[nn.Conv2d(next_channels, next_channels, kernel_size=3, padding=1),
                                 nn.BatchNorm2d(next_channels)
                                 nn.LeakyReLU(),
                                 nn.Conv2d(next_channels, next_channels, kernel_size=3, padding=1),
                                 nn.BatchNorm2d(next_channels),
                                 nn.LeakyReLU()
                                 ])

            if i == len(params) - 1:
                self.layers.append(nn.Tanh())
            else:
                self.layers.append(nn.BatchNorm2d(next_channels))
                self.layers.append(nn.LeakyReLU())

    def forward(self, x):
        if len(x.shape) == 2:
            x = x[:, :, None, None]

        return self.layers(x)