In [2]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from tqdm import tqdm

from pixyz.distributions import Normal, Bernoulli, RelaxedCategorical, Categorical

# Training hyperparameters.
batch_size = 128
epochs = 10

# Plot styling has nothing to do with CUDA availability, so it is no longer
# nested under the device check. No import of matplotlib is visible in this
# cell, so the style is applied only when an earlier cell has already bound
# pyplot to the name `plt`; otherwise this is a no-op instead of a NameError.
_pyplot = globals().get("plt")
if _pyplot is not None:
    _pyplot.style.use("ggplot")

# Select the compute device once, as a plain string: "cuda" when PyTorch can
# see a GPU, "cpu" otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
class Encoder(Normal):
    """Normal distribution over ``z`` conditioned on ``x`` and ``y``.

    ``x`` goes through a two-layer strided convolution stack (1 input
    channel) whose output is flattened to ``128 * 7 * 7`` features. Two
    linear heads read those features: ``fc1`` yields 40 values and ``fc2``
    yields ``label_num`` values that are multiplied element-wise by ``y``.
    The concatenation is mapped to ``2 * z_dim`` values, split into the
    location and the (softplus-transformed) scale.

    Args:
        z_dim: Size of the latent variable ``z``.
        label_num: Width of the ``fc2`` head; ``y`` must broadcast against a
            ``(batch, label_num)`` tensor (presumably a one-hot label --
            confirm against the caller).
    """

    def __init__(self, z_dim=63, label_num=10):
        super().__init__(cond_var=["x", "y"], var=["z"])

        self.z_dim = z_dim

        # Convolutional feature extractor.
        conv_layers = [
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),  # 28 -> 14
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 14 -> 7
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        ]
        self.conv_e = nn.Sequential(*conv_layers)

        # Two heads on the flattened conv features: an unconstrained one and
        # one that forward() multiplies by y.
        self.fc1 = nn.Sequential(nn.Linear(128 * 7 * 7, 40))
        self.fc2 = nn.Sequential(nn.Linear(128 * 7 * 7, label_num))

        # Maps the concatenated heads to the stacked (loc, raw scale) vector.
        head_layers = [
            nn.Linear(40 + label_num, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 2 * self.z_dim),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x, y):
        """Return the ``loc`` and ``scale`` parameters of the distribution.

        Returns:
            dict with ``"loc"`` (first ``z_dim`` output columns) and
            ``"scale"`` (softplus of the remaining columns).
        """
        features = self.conv_e(x).view(-1, 128 * 7 * 7)
        free_part = self.fc1(features)
        label_part = self.fc2(features) * y
        stats = self.fc(torch.cat([free_part, label_part], dim=1))
        return {
            "loc": stats[:, :self.z_dim],
            # softplus keeps the scale strictly positive
            "scale": F.softplus(stats[:, self.z_dim:]),
        }


class Decoder(Bernoulli):
    """Bernoulli distribution over ``x`` conditioned on ``z`` and ``y``.

    Two linear heads read ``z``: ``fc1`` yields 40 values and ``fc2`` yields
    ``label_num`` values that are multiplied element-wise by ``y``. Their
    concatenation is expanded to ``128 * 7 * 7`` features, reshaped to
    ``(-1, 128, 7, 7)`` and upsampled by two stride-2 transposed
    convolutions ending in a sigmoid.

    Args:
        z_dim: Size of the latent variable ``z``.
        label_num: Width of the ``fc2`` head; ``y`` must broadcast against a
            ``(batch, label_num)`` tensor (presumably a one-hot label --
            confirm against the caller).
    """

    def __init__(self, z_dim=63, label_num=10):
        super().__init__(cond_var=["z", "y"], var=["x"])

        self.z_dim = z_dim

        # Two heads on z: an unconstrained one and one that forward()
        # multiplies by y.
        self.fc1 = nn.Sequential(nn.Linear(self.z_dim, 40))
        self.fc2 = nn.Sequential(nn.Linear(self.z_dim, label_num))

        # Dense expansion up to the size of a 128 x 7 x 7 feature map.
        dense_layers = [
            nn.Linear(40 + label_num, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 128 * 7 * 7),
            nn.LeakyReLU(0.2),
        ]
        self.fc_d = nn.Sequential(*dense_layers)

        # Transposed convolutions; each one doubles the spatial size.
        deconv_layers = [
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        ]
        self.conv_d = nn.Sequential(*deconv_layers)

    def forward(self, z, y):
        """Return the ``probs`` parameter of the distribution.

        Returns:
            dict with ``"probs"``: the sigmoid output of the transposed
            convolution stack (single channel).
        """
        free_part = self.fc1(z)
        label_part = self.fc2(z) * y
        hidden = self.fc_d(torch.cat([free_part, label_part], dim=1))
        feature_maps = hidden.view(-1, 128, 7, 7)
        return {"probs": self.conv_d(feature_maps)}
    
    
# classifier p(y|x)
class Classifier(RelaxedCategorical):
    """Relaxed categorical distribution over ``y`` conditioned on ``x``.

    ``x`` goes through the same kind of two-layer strided convolution stack
    used by ``Encoder`` (1 input channel, flattened to ``128 * 7 * 7``
    features), followed by a fully connected head that produces one score
    per class. The scores are normalised with a softmax so that the returned
    ``probs`` are non-negative and sum to one along the class axis.

    Args:
        label_num: Number of classes. Defaults to 10, matching the
            ``label_num`` default of ``Encoder`` and ``Decoder``. This
            replaces a reference to an undefined global ``y_dim``.
    """

    def __init__(self, label_num=10):
        super(Classifier, self).__init__(cond_var=["x"], var=["y"], temperature=0.1, name="p")

        # Convolutional feature extractor.
        self.conv_e = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),  # 28 -> 14
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 14 -> 7
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )

        # Classification head on the flattened conv features.
        self.fc = nn.Sequential(
            nn.Linear(128 * 7 * 7, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, label_num),
        )

    def forward(self, x):
        """Return the ``probs`` parameter of the distribution.

        Returns:
            dict with ``"probs"``: a ``(batch, label_num)`` tensor of class
            probabilities.
        """
        # The extractor attribute is `conv_e`; there is no `self.conv`.
        h = self.conv_e(x)
        # nn.Linear acts on the last dimension, so the (128, 7, 7) feature
        # maps have to be flattened before the fully connected head.
        h = h.view(-1, 128 * 7 * 7)
        logits = self.fc(h)
        # The head ends in a plain Linear layer, so its output is
        # unnormalised; convert it to probabilities before returning it
        # under the "probs" key.
        return {"probs": F.softmax(logits, dim=1)}
    