In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchviz import make_dot
from torchinfo import summary
import numpy as np

In [9]:
print(torch.cuda.is_available())

True


In [10]:
print(torch.cuda.device_count())

1


In [21]:
class Encoder(nn.Module):
    """Encoder"""

    def __init__(self, z_dim=32):
        super().__init__()
        self.z_dim = z_dim

        self.conv1 = nn.Conv2d(3, 32, 4, stride=2)
        self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
        self.conv3 = nn.Conv2d(64, 128, 4, stride=2) 
        self.conv4 = nn.Conv2d(128, 256, 4, stride=2)

        self.mu = nn.Linear(256*2*2, z_dim)
        self.logvar = nn.Linear(256*2*2, z_dim)

        self.relu = nn.ReLU()

    def forward(self, x):
        # x: [batch_size, 3, 64, 64]
        c1 = self.conv1(x) # (64x64x3) -> (31x31x32)
        h1 = self.relu(c1) # (31x31x32) -> (31x31x32)
        c2 = self.conv2(h1) # (31x31x32) -> (14x14x64)
        h2 = self.relu(c2) # (14x14x64) -> (14x14x64)
        c3 = self.conv3(h2) # (14x14x64) -> (6x6x128)
        h3 = self.relu(c3) # (6x6x128) -> (6x6x128)
        c4 = self.conv4(h3) # (6x6x128) -> (2x2x256)
        h4 = self.relu(c4) # (2x2x256) -> (2x2x256)

        d1 = h4.view(-1, 256*2*2) # (2x2x256) -> (1024)

        mu = self.mu(d1) # (1024) -> (32)
        logvar = self.logvar(d1) # (1024) -> (32)
        var = torch.exp(logvar)
        std = torch.sqrt(var)

        ep = torch.randn_like(std)

        z = mu + ep * std # (32) -> (32)

        return z, mu, logvar

In [22]:
class Decoder(nn.Module):
    """Decoder"""
    def __init__(self, z_dim=32):
        super().__init__()
        self.z_dim = z_dim

        self.l1 = nn.Linear(z_dim, 1024*1*1)

        self.deconv1 = nn.ConvTranspose2d(1024, 128, 5, stride=2)
        self.deconv2 = nn.ConvTranspose2d(128, 64, 5, stride=2)
        self.deconv3 = nn.ConvTranspose2d(64, 32, 6, stride=2)
        self.deconv4 = nn.ConvTranspose2d(32, 3, 6, stride=2)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, z):
        # z: [batch_size, 32]
        d1 = self.l1(z) # (32) -> (1024)
        d1 = d1.view(-1, 1024, 1, 1) # (1024) -> (1x1x1024)

        dc1 = self.deconv1(d1) # (1x1x1024) -> (5x5x128)
        h1 = self.relu(dc1) # (5x5x128) -> (5x5x128)
        dc2 = self.deconv2(h1) # (5x5x128) -> (13x13x64)
        h2 = self.relu(dc2) # (13x13x64) -> (13x13x64)
        dc3 = self.deconv3(h2) # (13x13x64) -> (30x30x32)
        h3 = self.relu(dc3) # (30x30x32) -> (30x30x32)
        dc4 = self.deconv4(h3) # (30x30x32) -> (64x64x3)
        x = self.sigmoid(dc4) # (64x64x3) -> (64x64x3)

        return x

In [25]:
class VAE(nn.Module):
    '''VAE'''
    def __init__(self, z_dim=32):
        super().__init__()
        self.z_dim = z_dim

        self.encoder = Encoder(z_dim)
        self.decoder = Decoder(z_dim)

    def forward(self, x):
        z, mu, logvar = self.encoder(x)
        x = self.decoder(z)

        return x, mu, logvar

In [17]:
encoder = Encoder()

In [18]:
summary(encoder, (3, 64, 64))

Layer (type:depth-idx)                   Output Shape              Param #
Encoder                                  [1, 32]                   --
├─Conv2d: 1-1                            [32, 31, 31]              1,568
├─ReLU: 1-2                              [32, 31, 31]              --
├─Conv2d: 1-3                            [64, 14, 14]              32,832
├─ReLU: 1-4                              [64, 14, 14]              --
├─Conv2d: 1-5                            [128, 6, 6]               131,200
├─ReLU: 1-6                              [128, 6, 6]               --
├─Conv2d: 1-7                            [256, 2, 2]               524,544
├─ReLU: 1-8                              [256, 2, 2]               --
├─Linear: 1-9                            [1, 32]                   32,800
├─Linear: 1-10                           [1, 32]                   32,800
Total params: 755,744
Trainable params: 755,744
Non-trainable params: 0
Total mult-adds (M): 400.37
Input size (MB): 0.05
Forward/

In [23]:
decoder = Decoder()

In [24]:
summary(decoder, (32,))

Layer (type:depth-idx)                   Output Shape              Param #
Decoder                                  [1, 3, 64, 64]            --
├─Linear: 1-1                            [1024]                    33,792
├─ConvTranspose2d: 1-2                   [1, 128, 5, 5]            3,276,928
├─ReLU: 1-3                              [1, 128, 5, 5]            --
├─ConvTranspose2d: 1-4                   [1, 64, 13, 13]           204,864
├─ReLU: 1-5                              [1, 64, 13, 13]           --
├─ConvTranspose2d: 1-6                   [1, 32, 30, 30]           73,760
├─ReLU: 1-7                              [1, 32, 30, 30]           --
├─ConvTranspose2d: 1-8                   [1, 3, 64, 64]            3,459
├─Sigmoid: 1-9                           [1, 3, 64, 64]            --
Total params: 3,592,803
Trainable params: 3,592,803
Non-trainable params: 0
Total mult-adds (M): 231.70
Input size (MB): 0.00
Forward/backward pass size (MB): 0.45
Params size (MB): 14.37
Estimated Tota

In [55]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vae = VAE()

In [32]:
summary(vae, (3, 64, 64))

Layer (type:depth-idx)                   Output Shape              Param #
VAE                                      [1, 3, 64, 64]            --
├─Encoder: 1-1                           [1, 32]                   --
│    └─Conv2d: 2-1                       [32, 31, 31]              1,568
│    └─ReLU: 2-2                         [32, 31, 31]              --
│    └─Conv2d: 2-3                       [64, 14, 14]              32,832
│    └─ReLU: 2-4                         [64, 14, 14]              --
│    └─Conv2d: 2-5                       [128, 6, 6]               131,200
│    └─ReLU: 2-6                         [128, 6, 6]               --
│    └─Conv2d: 2-7                       [256, 2, 2]               524,544
│    └─ReLU: 2-8                         [256, 2, 2]               --
│    └─Linear: 2-9                       [1, 32]                   32,800
│    └─Linear: 2-10                      [1, 32]                   32,800
├─Decoder: 1-2                           [1, 3, 64, 64]     

In [57]:
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

In [51]:
transform_data = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])

In [58]:
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform_data)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform_data)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


100.0%


Extracting ./data\cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [60]:
def criterion(predict, target, ave, log_dev):
  bce_loss = F.binary_cross_entropy(predict, target, reduction='sum')
  kl_loss = -0.5 * torch.sum(1 + log_dev - ave**2 - log_dev.exp())
  loss = bce_loss + kl_loss
  return loss

net = VAE().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
history = {'train_loss': []}

epochs = 10

for epoch in range(epochs):
    for i, (img, _) in enumerate(train_loader):
        img = img.to(device)
        output, mu, logvar = net(img)
        loss = criterion(output, img, mu, logvar)
        history['train_loss'].append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print(f'epoch: {epoch}, step: {i}, loss: {loss.item()}')

epoch: 0, step: 0, loss: 1090587.75
epoch: 0, step: 100, loss: 984764.5
epoch: 0, step: 200, loss: 972798.8125
epoch: 0, step: 300, loss: 968391.5
epoch: 1, step: 0, loss: 962117.5625
epoch: 1, step: 100, loss: 955147.0
epoch: 1, step: 200, loss: 937359.75
epoch: 1, step: 300, loss: 923683.75
epoch: 2, step: 0, loss: 929104.25
epoch: 2, step: 100, loss: 924408.5625
epoch: 2, step: 200, loss: 922848.875
epoch: 2, step: 300, loss: 936311.0625
epoch: 3, step: 0, loss: 935280.3125
epoch: 3, step: 100, loss: 909964.25
epoch: 3, step: 200, loss: 921540.75
epoch: 3, step: 300, loss: 916599.75
epoch: 4, step: 0, loss: 916079.4375
epoch: 4, step: 100, loss: 925119.1875
epoch: 4, step: 200, loss: 924801.125
epoch: 4, step: 300, loss: 924342.8125
epoch: 5, step: 0, loss: 934412.5
epoch: 5, step: 100, loss: 940165.1875
epoch: 5, step: 200, loss: 922508.3125
epoch: 5, step: 300, loss: 910573.3125
epoch: 6, step: 0, loss: 911282.0
epoch: 6, step: 100, loss: 925865.1875
epoch: 6, step: 200, loss: 940

In [62]:
g = make_dot(output, params=dict(net.named_parameters()))
g.view()

'Digraph.gv.pdf'