# Variational Autoencoder - Aerial Images Simulation

## General Imports

In [1]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as np
from torchvision.utils import save_image

In [2]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


## Helper Functions

In [3]:
def conv2d_output_dimensions(width_in, height_in, input_channels, kernel_size, kernel_filters, padding=0, stride=1):
    """Return ``(width_out, height_out, channels_out)`` of a square 2-D convolution.

    Uses the textbook formula ``floor((W - K + 2P) / S) + 1``. The division is
    floored, as a real convolution would do, so strides that do not divide
    evenly never produce fractional sizes.

    Args:
        width_in, height_in: spatial size of the input.
        input_channels: unused; kept so call sites stay uniform.
        kernel_size: side length of the square kernel.
        kernel_filters: number of filters, i.e. output channels.
        padding: zero padding on each side.
        stride: convolution stride.
    """
    width_out = (width_in - kernel_size + 2 * padding) // stride + 1
    height_out = (height_in - kernel_size + 2 * padding) // stride + 1
    output_channels = kernel_filters

    return width_out, height_out, output_channels

In [4]:
def maxpool_output_dimensions(width_in, height_in, input_channels, kernel_size, kernel_filters, padding=0, stride=1, dilation=1):
    """Return ``(width_out, height_out, channels_out)`` of a square 2-D max pool.

    Follows PyTorch's ``MaxPool2d`` formula (with ``ceil_mode=False``):
    ``floor((W + 2P - D*(K - 1) - 1) / S) + 1``. The effective kernel span is
    ``D*(K-1) + 1``, not ``D*K``. Pooling keeps the channel count.

    Args:
        width_in, height_in: spatial size of the input.
        input_channels: channel count, returned unchanged.
        kernel_size: side length of the pooling window.
        kernel_filters: unused; kept so call sites stay uniform.
        padding: implicit padding on each side.
        stride: pooling stride. The default of 1 differs from PyTorch's
            default of ``kernel_size``.
        dilation: spacing between window elements.
    """
    span = dilation * (kernel_size - 1) + 1
    width_out = (width_in + 2 * padding - span) // stride + 1
    height_out = (height_in + 2 * padding - span) // stride + 1
    output_channels = input_channels

    return width_out, height_out, output_channels

In [5]:
def convTranspose2d_dim(width_in, height_in, input_channels, kernel_size, kernel_filters, padding=0, stride=1):
    """Return ``(width_out, height_out, channels_out)`` of a square transposed conv.

    Uses ``S*(W - 1) + K - 2P`` per spatial axis (no dilation, no output
    padding). ``input_channels`` is unused; ``kernel_filters`` becomes the
    output channel count.
    """
    def _upsampled(size):
        return stride * (size - 1) + kernel_size - (2 * padding)

    return _upsampled(width_in), _upsampled(height_in), kernel_filters

In [6]:
def conv2d_output_dimensions_pytorch(width_in, height_in, input_channels, kernel_size, kernel_filters, padding=0, stride=1, dilation=1):
    """Return ``(width_out, height_out, channels_out)`` exactly as ``nn.Conv2d`` computes them.

    PyTorch formula: ``floor((W + 2P - D*(K - 1) - 1) / S) + 1``. The floor
    matters. For example, a 31x31 input with K=4, S=2, P=1 gives 15, not 15.5.

    Args:
        width_in, height_in: spatial size of the input.
        input_channels: unused; kept so call sites stay uniform.
        kernel_size: side length of the square kernel.
        kernel_filters: number of filters, i.e. output channels.
        padding: zero padding on each side.
        stride: convolution stride.
        dilation: kernel dilation.
    """
    width_out = (width_in - (dilation * (kernel_size - 1)) + (2 * padding) - 1) // stride + 1
    height_out = (height_in - (dilation * (kernel_size - 1)) + (2 * padding) - 1) // stride + 1
    output_channels = kernel_filters

    return width_out, height_out, output_channels

In [7]:
def convTranspose2d_dim__pytorch(width_in, height_in, input_channels, kernel_size, kernel_filters, padding=0, stride=1, dilation=1, output_pad=0):
    """Return ``(width_out, height_out, channels_out)`` as ``nn.ConvTranspose2d`` computes them.

    Per axis: ``S*(W - 1) - 2P + D*(K - 1) + output_padding + 1``.
    ``input_channels`` is unused; ``kernel_filters`` is the output channel count.
    """
    effective_kernel = dilation * (kernel_size - 1)

    def _out(size):
        return stride * (size - 1) + effective_kernel - (2 * padding) + output_pad + 1

    return _out(width_in), _out(height_in), kernel_filters

## Resnet dimension computation

### Test residual blocks

Test the input and output dimensions of the residual blocks.

In [8]:
# First conv of a residual block: 3x3 kernel, stride 2, padding 1 halves 121 -> 61.
conv2d_output_dimensions(121, 121, 3, 3, 32, stride=2, padding=1)

(61.0, 61.0, 32)

In [9]:
# Second conv: 3x3 kernel, stride 1, padding 1 keeps 61 -> 61.
# (input_channels is not used by the helper, so the 3 here has no effect.)
conv2d_output_dimensions(61, 61, 3, 3, 32, stride=1, padding=1)

(61.0, 61.0, 32)

In [10]:
# Shortcut conv: 1x1 kernel, stride 2, no padding; it reaches the same 61x61 size
# as the two convs above, so the two paths can be added.
conv2d_output_dimensions(121, 121, 3, 1, 32, stride=2, padding=0)

(61.0, 61.0, 32)

### Aerial Images 120x120 Encoder

#### Input

In [11]:
# NOTE(review): this sketch uses 5x5 / 3x3 / 1x1 kernels and ends at 16x16x128.
# The Dronet class below instead uses 6x6 / 4x4 / 2x2 kernels and ends at
# 15x15x128 (see the "Encoder Pytorch" cells).
# Input conv: 5x5 kernel, padding 2 keeps 121 -> 121.
conv2d_output_dimensions(121, 121, 3, 5, 32, stride=1, padding=2)

(121.0, 121.0, 32)

#### Residual Block 1

In [12]:
# Main path, stride-2 conv: 121 -> 61.
conv2d_output_dimensions(121, 121, 3, 3, 32, stride=2, padding=1)

(61.0, 61.0, 32)

In [13]:
# Main path, stride-1 conv: 61 -> 61.
conv2d_output_dimensions(61, 61, 32, 3, 32, stride=1, padding=1)

(61.0, 61.0, 32)

In [14]:
# Shortcut 1x1 stride-2 conv: 121 -> 61, matching the main path.
conv2d_output_dimensions(121, 121, 3, 1, 32, padding=0, stride=2)

(61.0, 61.0, 32)

#### Residual Block 2

In [15]:
# Main path, stride-2 conv: 61 -> 31, channels 32 -> 64.
conv2d_output_dimensions(61, 61, 32, 3, 64, stride=2, padding=1)

(31.0, 31.0, 64)

In [16]:
# Main path, stride-1 conv: 31 -> 31.
conv2d_output_dimensions(31, 31, 64, 3, 64, stride=1, padding=1)

(31.0, 31.0, 64)

In [17]:
# Shortcut 1x1 stride-2 conv: 61 -> 31.
conv2d_output_dimensions(61, 61, 32, 1, 64, stride=2, padding=0)

(31.0, 31.0, 64)

#### Residual Block 3

In [18]:
# Main path, stride-2 conv: 31 -> 16, channels 64 -> 128.
conv2d_output_dimensions(31, 31, 64, 3, 128, stride=2, padding=1)

(16.0, 16.0, 128)

In [19]:
# Main path, stride-1 conv: 16 -> 16.
conv2d_output_dimensions(16, 16, 128, 3, 128, stride=1, padding=1)

(16.0, 16.0, 128)

In [20]:
# Shortcut 1x1 stride-2 conv: 31 -> 16.
conv2d_output_dimensions(31, 31, 64, 1, 128, stride=2, padding=0)

(16.0, 16.0, 128)

### Aerial Images 120x120 Encoder Pytorch

In [21]:
# Same parameters as Dronet.conv0 (6x6, stride 1, padding 2): 121 -> 120.
conv2d_output_dimensions_pytorch(121, 121, 3, 6, 32, stride=1, padding=2)

(120.0, 120.0, 32)

In [22]:
# Same parameters as Dronet.conv1_1 (4x4, stride 2, padding 1): 120 -> 60.
# Dronet then replicate-pads this map by one pixel to 61x61.
conv2d_output_dimensions_pytorch(120, 120, 32, 4, 32, stride=2, padding=1)

(60.0, 60.0, 32)

In [23]:
# Same parameters as Dronet.conv1_2 (4x4, stride 1, padding 1) on the padded map: 61 -> 60.
conv2d_output_dimensions_pytorch(61, 61, 32, 4, 32, stride=1, padding=1)

(60.0, 60.0, 32)

In [24]:
# Same parameters as Dronet.conv1_res_down (2x2, stride 2): 120 -> 60, matching the main path.
conv2d_output_dimensions_pytorch(120, 120, 32, 2, 32, stride=2, padding=0)

(60.0, 60.0, 32)

### Aerial Images 120x120 Decoder Pytorch

In [25]:
# Same parameters as ImageDecoder.convT1 (stride 1, padding 1): 15 -> 16.
convTranspose2d_dim__pytorch(15, 15, 128, 4, 128, padding=1, stride=1)

(16, 16, 128)

In [26]:
# convT2 (stride 2, padding 2): 16 -> 30.
convTranspose2d_dim__pytorch(16, 16, 128, 4, 64, padding=2, stride=2)

(30, 30, 64)

In [27]:
# convT3: 30 -> 31.
convTranspose2d_dim__pytorch(30, 30, 64, 4, 64, padding=1, stride=1)

(31, 31, 64)

In [28]:
# convT4: 31 -> 60.
convTranspose2d_dim__pytorch(31, 31, 64, 4, 32, padding=2, stride=2)

(60, 60, 32)

In [29]:
# convT5: 60 -> 61.
convTranspose2d_dim__pytorch(60, 60, 32, 4, 32, padding=1, stride=1)

(61, 61, 32)

In [30]:
# convT6: 61 -> 120.
convTranspose2d_dim__pytorch(61, 61, 32, 4, 32, padding=2, stride=2)

(120, 120, 32)

In [31]:
# convT7: 120 -> 121 with 3 output channels, matching the padded 121x121 input images.
convTranspose2d_dim__pytorch(120, 120, 32, 4, 3, padding=1, stride=1)

(121, 121, 3)

### Aerial Images 120x120 Decoder

In [32]:
# Early decoder sketch, starting from a 16x16 map. The implemented ImageDecoder
# starts from 15x15 instead (see the "Decoder Pytorch" cells above).
convTranspose2d_dim(16, 16, 128, 4, 128, padding=1, stride=1)

(17, 17, 128)

In [33]:
convTranspose2d_dim(17, 17, 128, 4, 64, padding=2, stride=2)

(32, 32, 64)

In [34]:
convTranspose2d_dim(32, 32, 64, 3, 64, padding=1, stride=1)

(32, 32, 64)

In [35]:
convTranspose2d_dim(32, 32, 64, 3, 32, padding=2, stride=2)

(61, 61, 32)

In [36]:
convTranspose2d_dim(61, 61, 32, 3, 32, padding=1, stride=1)

(61, 61, 32)

In [37]:
# NOTE(review): 67 looks like a typo for 61, the previous cell's output. With 61
# this call gives 119, which is exactly the input used in the next cell.
convTranspose2d_dim(67, 67, 32, 3, 32, padding=2, stride=2)

(131, 131, 32)

In [38]:
convTranspose2d_dim(119, 119, 32, 6, 3, padding=2, stride=1)

(120, 120, 3)

## Project Configs

In [39]:
# --- Data loading -----------------------------------------------------------
batch_size = 32    # samples per batch
num_workers = 10   # DataLoader worker subprocesses

# --- Optimisation -----------------------------------------------------------
learning_rate = 1e-4
epochs = 200

# Print a progress line every `log_interval` training batches.
log_interval = 100

## Models

### My Dronet Pytorch Implementation

In [40]:
class Dronet(nn.Module):
    """Residual CNN encoder mapping an image batch to ``output_dim`` features.

    Layout: a 6x6 input convolution, three pre-activation residual blocks,
    then ``fc1`` (128 units, ReLU, dropout p=0.5) and ``fc2`` (``output_dim``
    units).

    For a 3x121x121 input the spatial size goes
    121 -> 120 (conv0) -> 60 (block 1) -> 30 (block 2) -> 15 (block 3),
    so ``fc1`` expects 128 * 15 * 15 flattened features. An input whose
    flattened feature size differs raises a shape error in ``fc1``. It is not
    silently re-batched.

    Args:
        output_dim: size of the returned feature vector (default 64).
    """

    def __init__(self, output_dim=64):
        super().__init__()
        # Input layer: 121x121 -> 120x120 (kernel 6, padding 2).
        self.conv0 = nn.Conv2d(3, 32, kernel_size=6, stride=1, padding=2)

        # Residual Block Layer 1 (32 -> 32 channels, 120 -> 60)
        self.bn1_1 = nn.BatchNorm2d(32)
        self.conv1_1 = nn.Conv2d(32, 32, kernel_size=4, stride=2, padding=1)
        self.bn1_2 = nn.BatchNorm2d(32)
        self.conv1_2 = nn.Conv2d(32, 32, kernel_size=4, stride=1, padding=1)
        self.conv1_res_down = nn.Conv2d(32, 32, kernel_size=2, stride=2, padding=0)

        # Residual Block Layer 2 (32 -> 64 channels, 60 -> 30)
        self.bn2_1 = nn.BatchNorm2d(32)
        self.conv2_1 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.bn2_2 = nn.BatchNorm2d(64)
        self.conv2_2 = nn.Conv2d(64, 64, kernel_size=4, stride=1, padding=1)
        self.conv2_res_down = nn.Conv2d(32, 64, kernel_size=2, stride=2, padding=0)

        # Residual Block Layer 3 (64 -> 128 channels, 30 -> 15)
        self.bn3_1 = nn.BatchNorm2d(64)
        self.conv3_1 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.bn3_2 = nn.BatchNorm2d(128)
        self.conv3_2 = nn.Conv2d(128, 128, kernel_size=4, stride=1, padding=1)
        self.conv3_res_down = nn.Conv2d(64, 128, kernel_size=2, stride=2, padding=0)

        self.dropout1 = nn.Dropout(p=0.5)

        # Fully connected head on the flattened 128x15x15 feature map.
        self.fc1 = nn.Linear(15 * 15 * 128, 128)
        self.fc2 = nn.Linear(128, output_dim)

    @staticmethod
    def _residual_block(x, bn_a, conv_a, bn_b, conv_b, conv_skip):
        """Apply one pre-activation residual block and return its output.

        Main path: BN -> ReLU -> stride-2 conv -> replicate-pad right/bottom
        by one pixel -> BN -> ReLU -> stride-1 conv. Skip path: stride-2
        2x2 conv on the block input. The two paths are summed.
        """
        out = conv_a(F.relu(bn_a(x)))
        # The stride-1 conv (kernel 4, padding 1) shrinks its input by one
        # pixel. Padding by one first makes the main path end at the same
        # size as the skip path (e.g. 60 -> 61 -> 60).
        out = F.pad(out, (0, 1, 0, 1), mode='replicate')
        out = conv_b(F.relu(bn_b(out)))
        return torch.add(out, conv_skip(x))

    def forward(self, x):
        """Encode an image batch ``x`` into ``(batch, output_dim)`` features."""
        x0 = self.conv0(x)

        res1out = self._residual_block(x0, self.bn1_1, self.conv1_1,
                                       self.bn1_2, self.conv1_2, self.conv1_res_down)
        res2out = self._residual_block(res1out, self.bn2_1, self.conv2_1,
                                       self.bn2_2, self.conv2_2, self.conv2_res_down)
        res3out = self._residual_block(res2out, self.bn3_1, self.conv3_1,
                                       self.bn3_2, self.conv3_2, self.conv3_res_down)

        # Keep the batch axis and flatten the rest. view(-1, 15*15*128) would
        # fold a wrongly sized feature map into extra "batch" rows.
        x = torch.flatten(res3out, 1)
        x = self.dropout1(F.relu(self.fc1(x)))
        return self.fc2(x)

In [41]:
# Build the standalone encoder (64-dim output) and show its layer summary.
model = Dronet(output_dim=64)
model

Dronet(
  (conv0): Conv2d(3, 32, kernel_size=(6, 6), stride=(1, 1), padding=(2, 2))
  (bn1_1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv1_1): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (bn1_2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv1_2): Conv2d(32, 32, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  (conv1_res_down): Conv2d(32, 32, kernel_size=(2, 2), stride=(2, 2))
  (bn2_1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2_1): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (bn2_2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2_2): Conv2d(64, 64, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  (conv2_res_down): Conv2d(32, 64, kernel_size=(2, 2), stride=(2, 2))
  (bn3_1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (co

### Decoder

In [42]:
class ImageDecoder(nn.Module):
    """Transposed-convolution decoder from a flat 128*15*15 vector to an image.

    Every layer uses a 4x4 kernel with padding equal to its stride. Stride-1
    layers grow the map by one pixel and stride-2 layers roughly double it:
    15 -> 16 -> 30 -> 31 -> 60 -> 61 -> 120 -> 121. The final layer emits
    3 channels with no output activation.
    """

    def __init__(self):
        super().__init__()

        def _upconv(in_ch, out_ch, stride):
            # padding == stride for every layer of this decoder
            return nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=stride, padding=stride)

        # Registration order is kept (convT1..convT7) so parameter
        # initialisation and the printed summary match.
        self.convT1 = _upconv(128, 128, 1)
        self.convT2 = _upconv(128, 64, 2)

        self.convT3 = _upconv(64, 64, 1)
        self.convT4 = _upconv(64, 32, 2)

        self.convT5 = _upconv(32, 32, 1)
        self.convT6 = _upconv(32, 32, 2)

        self.convT7 = _upconv(32, 3, 1)

    def forward(self, x):
        """Reshape ``x`` to (N, 128, 15, 15) and decode it to (N, 3, 121, 121)."""
        h = x.view(-1, 128, 15, 15)
        for hidden in (self.convT1, self.convT2, self.convT3,
                       self.convT4, self.convT5, self.convT6):
            h = F.relu(hidden(h))
        return self.convT7(h)

In [43]:
# Build the decoder on its own and show its layer summary.
decoder = ImageDecoder(
)
decoder

ImageDecoder(
  (convT1): ConvTranspose2d(128, 128, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  (convT2): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
  (convT3): ConvTranspose2d(64, 64, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  (convT4): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
  (convT5): ConvTranspose2d(32, 32, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  (convT6): ConvTranspose2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
  (convT7): ConvTranspose2d(32, 3, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
)

### VAE Aerial Images

In [44]:
class AerialImagesVAE(nn.Module):
    """Variational autoencoder for aerial images.

    Pipeline: ``Dronet`` encoder (64 features) -> ``fc1`` / ``fc2`` heads
    giving ``mu`` and ``logvar`` of a 64-dim latent Gaussian ->
    reparameterised sample ``z`` -> ``sampler`` (``fc3`` then ``fc4``) up to
    15*15*128 values -> ``ImageDecoder``.
    """

    def __init__(self):
        super().__init__()
        # Encoder
        self.q_img = Dronet()

        # Latent heads: mean (fc1) and log-variance (fc2)
        self.fc1 = nn.Linear(64, 64)
        self.fc2 = nn.Linear(64, 64)

        # Latent code -> flat decoder input
        self.fc3 = nn.Linear(64, 64)
        self.fc4 = nn.Linear(64, 15 * 15 * 128)

        # Decoder
        self.p_img = ImageDecoder()

    def reparameterize(self, mu, logvar):
        """Return ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        A fresh sample is drawn in both train and eval mode. Gradients reach
        ``mu`` and ``logvar`` through the deterministic part.
        """
        std = torch.exp(0.5 * logvar)
        # randn_like replaces the legacy ``std.data.new(std.size()).normal_()``.
        # It gives the same shape, dtype and device without touching ``.data``.
        eps = torch.randn_like(std)
        return mu + eps * std

    def sampler(self, z):
        """Map latent codes ``z`` (last dim 64) to flat decoder inputs (15*15*128).

        NOTE(review): there is no nonlinearity between fc3 and fc4, so the two
        layers compose to a single affine map.
        """
        return self.fc4(self.fc3(z))

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for an image batch ``x``."""
        features = self.q_img(x)

        mu = self.fc1(features)
        logvar = self.fc2(features)

        z = self.reparameterize(mu, logvar)
        img = self.p_img(self.sampler(z))

        return img, mu, logvar


In [45]:
# Build the full VAE (encoder + latent heads + decoder) and show its summary.
model = AerialImagesVAE(
)
model

AerialImagesVAE(
  (q_img): Dronet(
    (conv0): Conv2d(3, 32, kernel_size=(6, 6), stride=(1, 1), padding=(2, 2))
    (bn1_1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv1_1): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (bn1_2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv1_2): Conv2d(32, 32, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    (conv1_res_down): Conv2d(32, 32, kernel_size=(2, 2), stride=(2, 2))
    (bn2_1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2_1): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (bn2_2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2_2): Conv2d(64, 64, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    (conv2_res_down): Conv2d(32, 64, kernel_size=(2, 2), stride=(2, 2))
    (bn3_1): BatchNorm2d(64, eps=1e-05, moment

### Loss Function

In [46]:
class VAE_Loss(nn.Module):
    """Negative-ELBO style objective: summed squared error plus KL divergence.

    The KL term is the closed form for ``N(mu, exp(logvar))`` against
    ``N(0, I)``, summed over all elements. Both terms are sums over the batch,
    not means.
    """

    def __init__(self):
        super().__init__()
        self.mse_loss = nn.MSELoss(reduction="sum")

    def forward(self, x_recon, x, mu, logvar):
        reconstruction_term = self.mse_loss(x_recon, x)
        kl_elements = 1 + logvar - mu.pow(2) - logvar.exp()
        kl_term = -0.5 * torch.sum(kl_elements)
        return reconstruction_term + kl_term

## Data Loaders

In [47]:
# ImageFolder roots for the training and test splits.
root_train, root_test = (
    "../data/drone_image_data_1k/Training",
    "../data/drone_image_data_1k/Test",
)

In [48]:
# Replicate the edge pixels one step to the right and bottom (fill is ignored in
# 'edge' mode), then convert to a tensor. The recorded batch shapes below are
# 3x121x121; presumably the raw images are 120x120 — TODO confirm.
transform = transforms.Compose(
    [
        transforms.Pad((0, 0, 1, 1), fill=0, padding_mode='edge'),
        transforms.ToTensor(),
    ]
)


In [49]:
# Training batches are shuffled every epoch.
train_loader = torch.utils.data.DataLoader(datasets.ImageFolder(root_train, transform=transform),
                                           batch_size=batch_size,
                                           num_workers=num_workers,
                                           shuffle=True)

# Evaluation needs no shuffling. Keeping the order fixed means the
# reconstruction grid that test() saves from the first batch shows the same
# images every epoch, so epochs can be compared side by side.
test_loader = torch.utils.data.DataLoader(datasets.ImageFolder(root_test, transform=transform),
                                          batch_size=batch_size,
                                          num_workers=num_workers,
                                          shuffle=False)

In [50]:
# Create an iterator over the training loader and pull one batch for inspection.
# DataLoader iterators no longer provide a ``.next()`` method in recent
# PyTorch releases; the builtin ``next()`` works on all versions.
dataiter = iter(train_loader)
batch_images, batch_labels = next(dataiter)

In [51]:
# Advance to the next training batch (builtin next(); ``.next()`` is gone in recent PyTorch).
batch_images, batch_labels = next(dataiter)

In [52]:
batch_images.size()  # (batch, channels, height, width)

torch.Size([32, 3, 121, 121])

In [53]:
batch_images[0].size()  # shape of a single image tensor

torch.Size([3, 121, 121])

## Training & Testing

### Model, Loss Function, and Optimizer Instances

In [54]:
# Build the VAE on the selected device, plus an Adam optimizer over its parameters.
model = AerialImagesVAE().to(device)
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
# NOTE(review): this rebinds the class name VAE_Loss to an instance, which
# train()/test() call as VAE_Loss(...). Re-running this cell without
# re-running the class-definition cell would call the instance with no
# arguments and fail.
VAE_Loss = VAE_Loss()

### Training

In [55]:
# Per-epoch average losses, appended by test() and train() respectively.
val_losses, train_losses = [], []

In [56]:
def train(epoch):
    """Run one optimisation pass over ``train_loader`` and log the losses.

    Uses the module-level ``model``, ``optimizer``, ``VAE_Loss``, ``device``
    and ``log_interval``. Prints a progress line every ``log_interval``
    batches, then appends the summed loss divided by the dataset size to
    ``train_losses``.
    """
    model.train()
    running_total = 0
    n_samples = len(train_loader.dataset)
    n_batches = len(train_loader)

    for step, (images, _labels) in enumerate(train_loader):
        images = images.to(device)

        optimizer.zero_grad()
        reconstruction, mu, logvar = model(images)
        batch_loss = VAE_Loss(reconstruction, images, mu, logvar)
        batch_loss.backward()
        running_total += batch_loss.item()
        optimizer.step()

        if step % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, step * len(images), n_samples,
                100. * step / n_batches,
                batch_loss.item() / len(images)))

    epoch_mean = running_total / n_samples
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, epoch_mean))
    train_losses.append(epoch_mean)

In [57]:
def test(epoch):
    """Evaluate on ``test_loader``, save a reconstruction grid, and log the loss.

    Uses the module-level ``model``, ``VAE_Loss`` and ``device``. For the first
    batch, writes up to 8 originals above their reconstructions to
    ``results/Fruit_reconstruction_<epoch>.png``. Appends the summed test loss
    divided by the dataset size to ``val_losses``.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += VAE_Loss(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                # recon_batch is the decoder output and is already image-shaped.
                # The old recon_batch.view(batch_size, 3, 121, 121) failed
                # whenever this batch held fewer than batch_size images.
                comparison = torch.cat([data[:n], recon_batch[:n]])
                save_image(comparison.cpu(),
                           'results/Fruit_reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
    val_losses.append(test_loss)

In [58]:
# Run `epochs` epochs, numbered from 1 as in the log messages. Each epoch is one
# training pass followed by one evaluation pass on the test set.
for epoch in range(1, epochs + 1):
    train(epoch)
    test(epoch)

====> Epoch: 1 Average loss: 12289.5943
====> Test set loss: 8198.9168
====> Epoch: 2 Average loss: 3538.7580
====> Test set loss: 2264.9236
====> Epoch: 3 Average loss: 1979.9419
====> Test set loss: 1857.4827
====> Epoch: 4 Average loss: 1801.2330
====> Test set loss: 1727.6685
====> Epoch: 5 Average loss: 1666.5278
====> Test set loss: 1524.1664
====> Epoch: 6 Average loss: 1593.2696
====> Test set loss: 1532.5765
====> Epoch: 7 Average loss: 1524.3122
====> Test set loss: 1433.1661
====> Epoch: 8 Average loss: 1454.4642
====> Test set loss: 1510.6218
====> Epoch: 9 Average loss: 1394.5182
====> Test set loss: 1264.1460
====> Epoch: 10 Average loss: 1355.7236
====> Test set loss: 1243.2261
====> Epoch: 11 Average loss: 1362.8789
====> Test set loss: 1464.9744
====> Epoch: 12 Average loss: 1297.2402
====> Test set loss: 1256.9643
====> Epoch: 13 Average loss: 1220.5796
====> Test set loss: 1173.5849
====> Epoch: 14 Average loss: 1222.8687
====> Test set loss: 1184.5258
====> Epoch: 1

====> Epoch: 72 Average loss: 723.5725
====> Test set loss: 674.5896
====> Epoch: 73 Average loss: 711.7749
====> Test set loss: 669.1487
====> Epoch: 74 Average loss: 721.3711
====> Test set loss: 671.9069
====> Epoch: 75 Average loss: 709.1593
====> Test set loss: 684.5305
====> Epoch: 76 Average loss: 705.9762
====> Test set loss: 671.9285
====> Epoch: 77 Average loss: 705.2874
====> Test set loss: 659.3584
====> Epoch: 78 Average loss: 703.7578
====> Test set loss: 667.7802
====> Epoch: 79 Average loss: 691.2654
====> Test set loss: 651.5913
====> Epoch: 80 Average loss: 697.7153
====> Test set loss: 673.4347
====> Epoch: 81 Average loss: 691.0898
====> Test set loss: 654.5075
====> Epoch: 82 Average loss: 685.9381
====> Test set loss: 652.1848
====> Epoch: 83 Average loss: 672.4540
====> Test set loss: 644.8919
====> Epoch: 84 Average loss: 679.6031
====> Test set loss: 640.4161
====> Epoch: 85 Average loss: 672.9035
====> Test set loss: 674.8247
====> Epoch: 86 Average loss: 669.

====> Test set loss: 579.2203
====> Epoch: 143 Average loss: 573.5310
====> Test set loss: 569.7326
====> Epoch: 144 Average loss: 566.2387
====> Test set loss: 584.1313
====> Epoch: 145 Average loss: 569.0739
====> Test set loss: 594.9425
====> Epoch: 146 Average loss: 566.4002
====> Test set loss: 575.7249
====> Epoch: 147 Average loss: 572.5061
====> Test set loss: 581.1506
====> Epoch: 148 Average loss: 553.6140
====> Test set loss: 575.7944
====> Epoch: 149 Average loss: 551.2899
====> Test set loss: 567.5969
====> Epoch: 150 Average loss: 557.2702
====> Test set loss: 586.6700
====> Epoch: 151 Average loss: 551.5871
====> Test set loss: 564.6134
====> Epoch: 152 Average loss: 565.4456
====> Test set loss: 575.4749
====> Epoch: 153 Average loss: 557.6523
====> Test set loss: 561.4644
====> Epoch: 154 Average loss: 549.6689
====> Test set loss: 569.5130
====> Epoch: 155 Average loss: 556.0486
====> Test set loss: 576.6522
====> Epoch: 156 Average loss: 542.9163
====> Test set loss:

In [None]:
# Generate images from the prior: z ~ N(0, I) -> model.sampler -> decoder.
# model.sampler() only yields the decoder's flat 15*15*128 input (28800 values
# per sample), which cannot be viewed as 3x121x121 images. The decoder
# (model.p_img) must run as well.
for epoch in range(1, epochs + 1):
    with torch.no_grad():
        z = torch.randn(64, 64).to(device)
        sample = model.p_img(model.sampler(z)).cpu()
        save_image(sample.view(64, 3, 121, 121),
                   'results/img_gen_sample_' + str(epoch) + '.png')