In [1]:
import torch
import torch.nn as nn
from collections import OrderedDict
import numpy as np

In [2]:
# Run the sibling notebook utils.ipynb via the IPython %run magic.
# NOTE(review): its contents are not visible here — presumably it supplies helper
# definitions used by later cells; confirm what it brings into scope.
%run utils.ipynb
# Pick the compute device once: CUDA when torch reports it as available, otherwise CPU.
# This module-level name is read as a global by cVAE.sampling and cVAE.generate below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single feature axis.

    An input of shape ``(N, d1, d2, ...)`` is returned with shape
    ``(N, d1 * d2 * ...)``; a 1-D input of shape ``(N,)`` becomes ``(N, 1)``.
    """

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        """Return ``x`` reshaped to two dimensions, keeping the first one.

        ``reshape`` is used rather than ``view`` so that non-contiguous inputs
        (for example the result of ``permute``) are accepted; for contiguous
        inputs it still returns a view. The feature count is computed
        explicitly instead of passing ``-1`` because ``-1`` cannot be inferred
        when the batch dimension is 0.
        """
        batch_size = x.shape[0]
        # torch.Size.numel() is the product of the remaining dims (1 if none).
        feature_count = x.shape[1:].numel()
        return x.reshape(batch_size, feature_count)


class MLP(nn.Module):
    """Fully connected stack described by a list of layer widths.

    ``hidden_size[k]`` and ``hidden_size[k + 1]`` are the input and output
    widths of the k-th ``nn.Linear``. Every linear layer except the last is
    followed by ``BatchNorm1d`` and an in-place ``ReLU``; the last linear
    layer gets that pair only when ``last_activation`` is truthy.

    Sub-modules are registered in ``self.mlp`` under the names
    ``Linear_k``, ``BatchNorm_k`` and ``ReLU_k``.
    """

    def __init__(self, hidden_size, last_activation=True):
        super(MLP, self).__init__()
        final_index = len(hidden_size) - 2
        layers = OrderedDict()
        width_pairs = zip(hidden_size[:-1], hidden_size[1:])
        for idx, (n_in, n_out) in enumerate(width_pairs):
            layers["Linear_%d" % idx] = nn.Linear(n_in, n_out)
            # Only the final linear layer may be left without norm + activation.
            if idx == final_index and not last_activation:
                continue
            layers["BatchNorm_%d" % idx] = nn.BatchNorm1d(n_out)
            layers["ReLU_%d" % idx] = nn.ReLU(inplace=True)
        self.mlp = nn.Sequential(layers)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        return self.mlp(x)

In [5]:
class Encoder(nn.Module):
    """Convolutional encoder with two output heads, ``calc_mean`` and ``calc_logvar``.

    The trunk ``self.encode`` applies two unpadded 5x5 convolutions, a 2x2
    max-pool, two unpadded 3x3 convolutions, another 2x2 max-pool, flattens,
    and reduces to 128 features with an MLP. Each head is an MLP mapping
    ``128 + ncond`` features to ``nhid`` outputs with no final activation.

    Args:
        shape: ``(c, h, w)`` of the input images.
        nhid: Width of each head's output.
        ncond: Width of the optional conditioning vector that ``forward``
            concatenates to the trunk features (0 when unconditioned).

    Raises:
        ValueError: If ``h`` or ``w`` is too small to leave at least one
            spatial position after the trunk (each must be at least 20).
    """

    def __init__(self, shape, nhid=16, ncond=0):
        super(Encoder, self).__init__()
        c, h, w = shape
        # Spatial extent after the trunk: the two 5x5 convs remove 8, the pool
        # halves, the two 3x3 convs remove 4, the second pool halves again.
        ww = ((w - 8) // 2 - 4) // 2
        hh = ((h - 8) // 2 - 4) // 2
        # Floor division lets these go negative for tiny inputs, and the
        # product of two negatives is positive, so check each one explicitly.
        if ww < 1 or hh < 1:
            raise ValueError(
                "Encoder input of height %d and width %d is too small; "
                "both must be at least 20" % (h, w))
        self.encode = nn.Sequential(nn.Conv2d(c, 16, 5, padding=0), nn.BatchNorm2d(16), nn.ReLU(inplace=True),
                                    nn.Conv2d(16, 32, 5, padding=0), nn.BatchNorm2d(32), nn.ReLU(inplace=True),
                                    nn.MaxPool2d(2, 2),
                                    nn.Conv2d(32, 64, 3, padding=0), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
                                    nn.Conv2d(64, 64, 3, padding=0), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
                                    nn.MaxPool2d(2, 2),
                                    Flatten(), MLP([ww * hh * 64, 256, 128])
                                    )
        self.calc_mean = MLP([128 + ncond, 64, nhid], last_activation=False)
        self.calc_logvar = MLP([128 + ncond, 64, nhid], last_activation=False)

    def forward(self, x, y=None):
        """Return ``(calc_mean(f), calc_logvar(f))`` for the trunk features ``f``.

        When ``y`` is given it is concatenated to the trunk features along
        dim 1 before both heads are applied.
        """
        features = self.encode(x)
        if y is not None:
            # Build the conditioned input once and share it between both heads.
            features = torch.cat((features, y), dim=1)
        return self.calc_mean(features), self.calc_logvar(features)


class Decoder(nn.Module):
    """MLP decoder that maps a latent vector to a tensor of the stored ``shape``.

    ``self.decode`` is an MLP with widths
    ``[nhid + ncond, 64, 128, 256, prod(shape)]`` (no norm/activation after
    the last linear layer) followed by a ``Sigmoid``.

    Args:
        shape: Three-element shape of one output sample; stored as
            ``self.shape`` and used verbatim to reshape the output.
        nhid: Width of the latent vector ``z``.
        ncond: Width of the optional conditioning vector (0 when unused).
    """

    def __init__(self, shape, nhid=16, ncond=0):
        super(Decoder, self).__init__()
        channels, dim_a, dim_b = shape
        self.shape = shape
        n_outputs = channels * dim_a * dim_b
        layer_widths = [nhid + ncond, 64, 128, 256, n_outputs]
        self.decode = nn.Sequential(
            MLP(layer_widths, last_activation=False),
            nn.Sigmoid(),
        )

    def forward(self, z, y=None):
        """Decode ``z`` (concatenated with ``y`` along dim 1 when given).

        The flat output of ``self.decode`` is viewed as ``(-1, *self.shape)``.
        """
        channels, dim_a, dim_b = self.shape
        decoder_input = z if y is None else torch.cat((z, y), dim=1)
        flat = self.decode(decoder_input)
        return flat.view(-1, channels, dim_a, dim_b)
class ResidualBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU with an identity skip connection.

    The convolution keeps the channel count, and its padding is derived from
    the kernel size so the spatial size is preserved as well; this is what
    makes the final ``y + x`` shape-compatible.

    Args:
        input_channels: Number of input (and output) channels.
        kernel_size: Odd int, or a pair of odd ints, as accepted by
            ``nn.Conv2d``.

    Raises:
        ValueError: If any kernel dimension is even or smaller than 1, since
            symmetric padding cannot then preserve the spatial size.
    """

    def __init__(self, input_channels, kernel_size):
        super(ResidualBlock, self).__init__()
        if isinstance(kernel_size, int):
            kernel_dims = (kernel_size, kernel_size)
        else:
            kernel_dims = tuple(kernel_size)
        if any(k < 1 or k % 2 == 0 for k in kernel_dims):
            raise ValueError(
                "ResidualBlock needs odd, positive kernel sizes so the skip "
                "connection shapes match; got %r" % (kernel_size,))
        # "Same" padding for a stride-1 odd kernel; equals 1 for kernel_size=3.
        padding = tuple(k // 2 for k in kernel_dims)
        self.conv1 = nn.Conv2d(input_channels, input_channels, kernel_size, padding=padding)
        self.bn    = nn.BatchNorm2d(input_channels)
        self.relu  = nn.ReLU()

    def forward(self, x):
        """Return ``relu(bn(conv1(x))) + x``."""
        y = self.conv1(x)
        y = self.bn(y)
        y = self.relu(y)

        return y + x
# class SRCNN(nn.Module):
#     def __init__(self):
#         super(SRCNN, self).__init__()
#         self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=2, padding_mode='replicate')
#
#         self.res1  = ResidualBlock(64, 3)
#         self.res2  = ResidualBlock(64, 3)
#         self.res3  = ResidualBlock(64, 3)
#         self.res4  = ResidualBlock(64, 3)
#         self.res5  = ResidualBlock(64, 3)
#
#         self.conv_int = nn.Conv2d(64, 64, kernel_size=3, padding=1)
#         self.tanh     = nn.Tanh()
#
#         self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=2, padding_mode='replicate')
#         self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2, padding_mode='replicate')
#
#     def forward(self, x):
#
#         x = self.conv1(x)
#         # chaining the residuals
#         y = self.res1(x)
#         y = self.res2(y)
#         y = self.res3(y)
#         y = self.res4(y)
#         y = self.res5(y)
#
#         y = self.conv_int(y)
#         y = y + x
#
#         y = self.conv2(y)
#         y = self.conv3(y)
#
#         y = self.tanh(y)
#         return y

In [6]:

class cVAE(nn.Module):
    """Conditional VAE: an embedded class label conditions both encoder and decoder.

    Args:
        shape: Image shape passed through to ``Encoder`` and ``Decoder``.
        nclass: Number of distinct class labels (rows of the embedding table).
        nhid: Latent width; also stored as ``self.dim``.
        ncond: Width of the label embedding fed to encoder and decoder.
    """

    def __init__(self, shape, nclass, nhid=16, ncond=16):
        super(cVAE, self).__init__()
        self.dim = nhid
        self.encoder = Encoder(shape, nhid, ncond=ncond)
        self.decoder = Decoder(shape, nhid, ncond=ncond)
        self.label_embedding = nn.Embedding(nclass, ncond)
        # self.sr = SRCNN()

    def sampling(self, mean, logvar):
        """Draw ``z = mean + eps * sigma`` with ``eps ~ N(0, I)``.

        ``logvar`` is treated as the log of the variance, so the standard
        deviation is ``exp(0.5 * logvar)``.
        NOTE(review): this assumes the training loss (not visible here) also
        treats ``logvar`` as log-variance — confirm.
        """
        # randn_like draws the noise with mean's own device and dtype, so the
        # result does not depend on a notebook-global device.
        eps = torch.randn_like(mean)
        sigma = torch.exp(0.5 * logvar)
        return mean + eps * sigma

    def forward(self, x, y):
        """Encode ``x`` conditioned on labels ``y``, sample ``z``, and decode it.

        Returns:
            ``(reconstruction, mean, logvar)``.
        """
        y = self.label_embedding(y)
        mean, logvar = self.encoder(x, y)
        z = self.sampling(mean, logvar)
        w = self.decoder(z, y)
        return w, mean, logvar

    def generate(self, class_idx):
        """Decode standard-normal latents for the given class label(s).

        Args:
            class_idx: An int, a 0-d tensor, or a 1-D tensor / sequence of
                class indices.

        Returns:
            One decoded sample (batch dimension removed) for a scalar
            ``class_idx``; otherwise a batch with one sample per index.

        The module's train/eval mode is not changed here, so the decoder's
        BatchNorm layers act according to whatever mode the caller has set.
        """
        # Use the device the embedding table lives on, so this works wherever
        # the model has been moved to.
        target_device = self.label_embedding.weight.device
        class_idx = torch.as_tensor(class_idx, device=target_device)
        is_scalar = class_idx.dim() == 0
        if is_scalar:
            class_idx = class_idx.unsqueeze(0)
        batch_size = class_idx.shape[0]
        z = torch.randn((batch_size, self.dim), device=target_device)
        y = self.label_embedding(class_idx)
        res = self.decoder(z, y)
        # Only a scalar request has its batch dimension stripped; an empty
        # batch is returned as an empty batch.
        if is_scalar:
            res = res.squeeze(0)
        return res