In [15]:
import os
import time
import torch
import torch.nn as nn
import argparse
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from collections import defaultdict

In [16]:
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--learning_rate", type=float, default=0.001)
parser.add_argument("--encoder_layer_sizes", type=list, default=[784, 256])
parser.add_argument("--decoder_layer_sizes", type=list, default=[256, 784])
parser.add_argument("--latent_size", type=int, default=2)
parser.add_argument("--print_every", type=int, default=100)
parser.add_argument("--fig_root", type=str, default='figs')
parser.add_argument("--conditional", default=True, action='store_true')

args = parser.parse_known_args()[0]

In [17]:
def idx2onehot(idx, n):
    # idx = y， 是label
    # n = 10. 因为label范围是0~9
    assert torch.max(idx).item() < n
    if idx.dim() == 1:
        # shape从[64]的list变成[64,1]
        idx = idx.unsqueeze(1)

    # shape: [64, 10]
    onehot = torch.zeros(idx.size(0), n)
    # 按行填充
    # idx: shape为[64, 1]的0~9的阵
    # 按idx的index填1
    onehot.scatter_(1, idx, 1)

    return onehot

In [27]:
idx = torch.Tensor([0, 1, 2 ,3])
#idx = idx.unsqueeze(1)
print(idx.size())

torch.Size([4])


In [18]:
class Encoder(nn.Module):

    def __init__(self, layer_sizes, latent_size, conditional, num_labels):

        super().__init__()

        self.conditional = conditional
        if self.conditional:
            # layer_size :　［794, 256］
            layer_sizes[0] += num_labels
            print("encoder layer_size: " + str(layer_sizes[0]))

        self.MLP = nn.Sequential()
        # layer_size[:-1]: [794] 
        # layer_size[1:]：[256]　
        # zip: [(794, 256)]
        for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            self.MLP.add_module(
                name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
            self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())

        # layer_sizes[-1] : 256
        self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
        self.linear_log_var = nn.Linear(layer_sizes[-1], latent_size)

    def forward(self, x, c=None):

        if self.conditional:
            #c = y， 是label
            # 转换之后c的shape: [64, 10]
            c = idx2onehot(c.cpu(), n=10).to(device)
            # 拼接后x的shape：[64, 794]
            x = torch.cat((x, c), dim=-1).to(device)

        # x shape:[64, 256]
        x = self.MLP(x)

        # means shape: [64, 2]
        means = self.linear_means(x)
        # log_vars shape: [64, 2]
        log_vars = self.linear_log_var(x)

        return means, log_vars

In [19]:
class Decoder(nn.Module):

    def __init__(self, layer_sizes, latent_size, conditional, num_labels):

        super().__init__()

        self.MLP = nn.Sequential()

        self.conditional = conditional
        if self.conditional:
            input_size = latent_size + num_labels
        else:
            input_size = latent_size

        # layer_size: [256, 784]
        # layer_size[:-1]:256
        # zip : [(12, 256), (256, 784)]
        for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
            self.MLP.add_module(
                name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
            if i+1 < len(layer_sizes):
                self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
            else:
                self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())

    def forward(self, z, c):

        if self.conditional:  
            c = idx2onehot(c.cpu(), n=10).to(device)
            z = torch.cat((z, c), dim=-1).to(device)

        x = self.MLP(z)

        return x

In [20]:
class VAE(nn.Module):

    def __init__(self, encoder_layer_sizes, latent_size, decoder_layer_sizes,
                 conditional=False, num_labels=0):

        super().__init__()

        if conditional:
            assert num_labels > 0

        assert type(encoder_layer_sizes) == list
        assert type(latent_size) == int
        assert type(decoder_layer_sizes) == list

        self.latent_size = latent_size

        self.encoder = Encoder(
            encoder_layer_sizes, latent_size, conditional, num_labels)
        self.decoder = Decoder(
            decoder_layer_sizes, latent_size, conditional, num_labels)

    def forward(self, x, c=None):

        if x.dim() > 2:
            # x shape: [64, 784]
            x = x.view(-1, 28*28)
        
        # batch_size: 64
        batch_size = x.size(0)
        # encoder 
        means, log_var = self.encoder(x, c)
        # reparameter: q_\phi(z|x, y)
        std = torch.exp(0.5 * log_var)
        # eps shape: [64, 2]
        eps = torch.randn([batch_size, self.latent_size]).cuda()
        # std shape: [64, 2]
        z = eps * std + means
        # decoder: p(x|z, y)
        recon_x = self.decoder(z, c)

        return recon_x, means, log_var, z

    def inference(self, n=1, c=None):
        # p_theta(z)
        batch_size = n
        # 假设z是x encode过来的, 即p_\theta(z|x，y)
        # 那么这里就是假设p_theta(z|x)为标准正态分布
        # 可以推得p_theta(x)为标准正态分布
        z = torch.randn([batch_size, self.latent_size]).cuda()

        # p_\theta(x|z, y)
        recon_x = self.decoder(z, c)

        return recon_x


In [21]:
vae = VAE(
    encoder_layer_sizes=args.encoder_layer_sizes,
    latent_size=args.latent_size,
    decoder_layer_sizes=args.decoder_layer_sizes,
    conditional=args.conditional,
    num_labels=10 if args.conditional else 0)
print(vae)

encoder layer_size: 794
VAE(
  (encoder): Encoder(
    (MLP): Sequential(
      (L0): Linear(in_features=794, out_features=256, bias=True)
      (A0): ReLU()
    )
    (linear_means): Linear(in_features=256, out_features=2, bias=True)
    (linear_log_var): Linear(in_features=256, out_features=2, bias=True)
  )
  (decoder): Decoder(
    (MLP): Sequential(
      (L0): Linear(in_features=12, out_features=256, bias=True)
      (A0): ReLU()
      (L1): Linear(in_features=256, out_features=784, bias=True)
      (sigmoid): Sigmoid()
    )
  )
)


In [30]:
dataset = MNIST(
    root='../data', train=True, transform=transforms.ToTensor(),
    download=True)
data_loader = DataLoader(
    dataset=dataset, batch_size=args.batch_size, shuffle=True)

i, (data, label) = next(enumerate(data_loader))
# print(i, data)
print(label.size())
print(data.size())

torch.Size([64])
torch.Size([64, 1, 28, 28])


In [8]:
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(args.seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ts = time.time()

dataset = MNIST(
    root='../data', train=True, transform=transforms.ToTensor(),
    download=True)
data_loader = DataLoader(
    dataset=dataset, batch_size=args.batch_size, shuffle=True)

def loss_fn(recon_x, x, mean, log_var):
    # recon_x: q_\phi(x|z, y)
    # x: p_\theta(x|z, y) = p_\theta(x) 为标准正态分布
    # 为什么论文里写的是p_\theta(y|x, z), z服从q_\phi(z|x, y)
    # 这里算的是p_\theta(x|z, y), z服从q_\phi(z|x, y)
    # 可以理解为算出recon_x 里包含了y？
    BCE = torch.nn.functional.binary_cross_entropy(
        recon_x.view(-1, 28*28), x.view(-1, 28*28), reduction='sum')
    # DKL(q_\phi(z|x, y)||p(z|y))
    KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())

    return (BCE + KLD) / x.size(0)

vae = VAE(
    encoder_layer_sizes=args.encoder_layer_sizes,
    latent_size=args.latent_size,
    decoder_layer_sizes=args.decoder_layer_sizes,
    conditional=args.conditional,
    num_labels=10 if args.conditional else 0).to(device)

optimizer = torch.optim.Adam(vae.parameters(), lr=args.learning_rate)

logs = defaultdict(list)

for epoch in range(args.epochs):

    tracker_epoch = defaultdict(lambda: defaultdict(dict))

    for iteration, (x, y) in enumerate(data_loader):

        x, y = x.to(device), y.to(device)

        if args.conditional:
            # x shape: [64, 1, 28, 28], y shape: [64]
            recon_x, mean, log_var, z = vae(x, y)
        else:
            recon_x, mean, log_var, z = vae(x)

        for i, yi in enumerate(y):
            id = len(tracker_epoch)
            tracker_epoch[id]['x'] = z[i, 0].item()
            tracker_epoch[id]['y'] = z[i, 1].item()
            tracker_epoch[id]['label'] = yi.item()

        loss = loss_fn(recon_x, x, mean, log_var)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logs['loss'].append(loss.item())

        if iteration % args.print_every == 0 or iteration == len(data_loader)-1:
            print("Epoch {:02d}/{:02d} Batch {:04d}/{:d}, Loss {:9.4f}".format(
                epoch, args.epochs, iteration, len(data_loader)-1, loss.item()))

            if args.conditional:
                c = torch.arange(0, 10).long().unsqueeze(1)
                # decoder: p_\theta(x|z, c)
                x = vae.inference(n=c.size(0), c=c)
            else:
                x = vae.inference(n=10)

            plt.figure()
            plt.figure(figsize=(5, 10))
            for p in range(10):
                plt.subplot(5, 2, p+1)
                if args.conditional:
                    plt.text(
                        0, 0, "c={:d}".format(c[p].item()), color='black',
                        backgroundcolor='white', fontsize=8)
                x = x.cpu()
                plt.imshow(x[p].view(28, 28).data.numpy())
                plt.axis('off')

            if not os.path.exists(os.path.join(args.fig_root, str(ts))):
                if not(os.path.exists(os.path.join(args.fig_root))):
                    os.mkdir(os.path.join(args.fig_root))
                os.mkdir(os.path.join(args.fig_root, str(ts)))

            plt.savefig(
                os.path.join(args.fig_root, str(ts),
                             "E{:d}I{:d}.png".format(epoch, iteration)),
                dpi=300)
            plt.clf()
            plt.close('all')

    df = pd.DataFrame.from_dict(tracker_epoch, orient='index')
    g = sns.lmplot(
        x='x', y='y', hue='label', data=df.groupby('label').head(100),
        fit_reg=False, legend=True)
    g.savefig(os.path.join(
        args.fig_root, str(ts), "E{:d}-Dist.png".format(epoch)),
        dpi=300)

encoder layer_size: 804


RuntimeError: size mismatch, m1: [64 x 794], m2: [804 x 256] at C:/w/1/s/tmp_conda_3.6_171155/conda/conda-bld/pytorch_1570813991702/work/aten/src\THC/generic/THCTensorMathBlas.cu:290