<a href="https://colab.research.google.com/github/axtonisaly1013/GAN_VAE_Comparison/blob/main/VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
cd drive/My Drive/Colab Notebooks/Machine Learning/Plots/VAE_Training&Validation

/content/drive/My Drive/Colab Notebooks/Machine Learning/Plots/VAE_Training&Validation


In [None]:
# Based on: https://debuggercafe.com/getting-started-with-variational-autoencoder-using-pytorch/

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import numpy as np
import imageio

from tqdm import tqdm
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.utils import save_image, make_grid

# learning parameters
batch_size = 512  # mini-batch size shared by the train and validation loaders
epochs = 150  # number of passes over the training set
features = 2  # latent size handed to Encoder/Decoder and create_noise (was 32)
sample_size = 64  # number of noise vectors decoded for the preview grid
nz = 2  # was 20
lr = 0.0002  # step size for both Adam optimizers
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# transforms: images become float tensors; to_pil_image converts tensors back
transform = transforms.Compose([transforms.ToTensor()])
to_pil_image = transforms.ToPILImage()

# datasets (downloaded into ../../data when not already present)
train_data = datasets.MNIST(root='../../data', train=True, download=True, transform=transform)
val_data = datasets.MNIST(root='../../data', train=False, download=True, transform=transform)

# loaders: only the training split is shuffled
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

# class Encoder(nn.Module):
#   def __init__(self, features):
#     super(Encoder, self).__init__()
#     self.features = features
#     self.encoder = nn.Sequential(
#         nn.Linear(784,512),
#         nn.LeakyReLU(0.2),
#         nn.Linear(512, 256),
#         nn.LeakyReLU(0.2),
#         nn.Linear(256,self.features*2),        
#     )

#   def reparameterize(self, mu, log_var):
#     # mu: mean from encoder's latent space
#     # log_var: log variance from encoder's latent space

#     std = torch.exp(0.5*log_var)
#     eps = torch.randn_like(std)
#     sample = mu + (eps * std)
#     return sample

#   def forward(self, x):
#     x = self.encoder(x)
#     x = x.view(-1, 2, self.features)

#     mu = x[:, 0, :]
#     log_var = x[:, 1, :]

#     z = self.reparameterize(mu, log_var)

#     return z, mu, log_var
# convolutional = 0
# class Decoder(nn.Module):
#   def __init__(self, features):
#     super(Decoder, self).__init__()
#     self.features = features
#     self.decoder1 = nn.Sequential(
#         nn.Linear(self.features, 512), #half the in features?
#         nn.LeakyReLU(0.2),
#         nn.Linear(512,1024),
#         nn.LeakyReLU(0.2),
#     )
#     self.decoder2 = nn.Sequential(
#         nn.Linear(1024,784),
#         nn.Sigmoid(),
#     )

#   def forward(self, z):
#     x = self.decoder1(z)
#     reconstruction = self.decoder2(x)
#     return reconstruction

## Convolutional Configuration (code: https://github.com/coolvision/vae_conv/blob/master/vae_conv_model_mnist.py)
class Encoder(nn.Module):
    """Convolutional encoder producing a sampled latent code.

    The conv stack ends in a 1024-channel layer whose output is reshaped with
    ``view(-1, 1024)``; for 1x28x28 inputs the spatial size shrinks
    28 -> 14 -> 7 -> 4 -> 1, so each image yields exactly 1024 values.

    Args:
        features: size of the latent vector (width of ``mu`` and ``log_var``).
    """

    def __init__(self, features):
        super().__init__()
        self.features = features

        conv_layers = [
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 1024, kernel_size=4, stride=1, padding=0, bias=False),
            nn.LeakyReLU(0.2),
        ]
        self.encoder = nn.Sequential(*conv_layers)

        # Shared hidden layer followed by two heads: mean and log-variance.
        self.fc1 = nn.Linear(1024, 512)
        self.fc21 = nn.Linear(512, features)
        self.fc22 = nn.Linear(512, features)

    def reparameterize(self, mu, log_var):
        """Return ``mu + eps * std`` with ``eps`` drawn from a standard normal.

        Args:
            mu: mean from the encoder's latent space.
            log_var: log variance from the encoder's latent space.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + (eps * std)

    def forward(self, x):
        """Encode ``x`` and return ``(z, mu, log_var)``."""
        flat = self.encoder(x).view(-1, 1024)
        hidden = self.fc1(flat)
        mu, log_var = self.fc21(hidden), self.fc22(hidden)
        return self.reparameterize(mu, log_var), mu, log_var
# Flag read further down the cell to pick the conv-specific code paths
# (input flattening, checkpoint loading, output folders, file names).
convolutional = 1


class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent vectors to images.

    A two-layer MLP lifts the latent vector to 1024 values, which are viewed
    as a 1024x1x1 feature map and upsampled 1 -> 4 -> 7 -> 14 -> 28 by the
    transposed convolutions. The final sigmoid keeps outputs in [0, 1].

    Args:
        features: size of the latent vector accepted by ``forward``.
    """

    def __init__(self, features):
        super().__init__()
        self.features = features

        deconv_layers = [
            nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(128, 1, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*deconv_layers)

        self.fc = nn.Sequential(
            nn.Linear(features, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
        )

    def forward(self, z):
        """Decode latent batch ``z`` into a batch of single-channel images."""
        feature_map = self.fc(z).view(-1, 1024, 1, 1)
        return self.decoder(feature_map)

#######################################################
# Build the two networks on the selected device, each with its own Adam optimizer.
encoder = Encoder(features).to(device)
decoder = Decoder(features).to(device)
optim_e = optim.Adam(encoder.parameters(), lr=lr)
optim_d = optim.Adam(decoder.parameters(), lr=lr)

# Only the non-convolutional configuration resumes from saved weights.
if not convolutional:
    for module, checkpoint in ((decoder, 'models/decoder.pth'),
                               (encoder, 'models/encoder.pth')):
        module.load_state_dict(torch.load(checkpoint, map_location=device))

# Reconstruction loss summed over all pixels of the batch.
criterion = nn.BCELoss(reduction='sum')

# Grids decoded from the fixed noise, one per epoch.
images = list()

def loss_function(bce_loss, mu, log_var):
    """Return the VAE objective: reconstruction loss plus KL divergence.

    The KL term is ``-0.5 * sum(1 + log_var - mu^2 - exp(log_var))``, summed
    over every element of ``mu`` / ``log_var``.

    Args:
        bce_loss: reconstruction loss already computed by the caller.
        mu: latent means.
        log_var: latent log-variances, same shape as ``mu``.
    """
    per_element = 1 + log_var - mu.pow(2) - log_var.exp()
    kl_divergence = -0.5 * per_element.sum()
    return kl_divergence + bce_loss

# function to create the noise vector
def create_noise(sample_size, features):
    """Draw a ``(sample_size, features)`` standard-normal tensor and move it to ``device``."""
    latent = torch.randn(sample_size, features)
    return latent.to(device)

def save_generator_image(image, path):
  """Write ``image`` to ``path`` by delegating to torchvision's ``save_image``."""
  save_image(image, path)

def train_VAE(dataloader):
    """Run one training epoch over ``dataloader`` and return the mean loss.

    Uses the module-level ``encoder``, ``decoder``, their optimizers
    ``optim_e`` / ``optim_d``, ``criterion`` and ``device``.

    Args:
        dataloader: yields ``(images, labels)`` batches; labels are ignored.
            Must expose ``.dataset`` with a length.

    Returns:
        Sum of the per-batch losses divided by ``len(dataloader.dataset)``.
    """
    encoder.train()
    decoder.train()

    running_loss = 0.0
    # Iterate the loader handed in by the caller so that the normalisation
    # below matches the data that was actually seen.
    for data, _ in dataloader:
        data = data.to(device)
        if not convolutional:
            # Fully connected models take flattened images.
            data = data.view(data.size(0), -1)

        optim_e.zero_grad()
        optim_d.zero_grad()

        z, mu, log_var = encoder(data)
        reconstruction = decoder(z)

        bce_loss = criterion(reconstruction, data)
        loss = loss_function(bce_loss, mu, log_var)
        running_loss += loss.item()
        loss.backward()
        optim_e.step()
        optim_d.step()

    train_loss = running_loss / len(dataloader.dataset)
    return train_loss

# Validate reconstruction performance
def validate_VAE(dataloader):
    """Evaluate reconstruction quality on ``dataloader`` and return the mean loss.

    Runs the module-level ``encoder`` / ``decoder`` in eval mode without
    gradients. On the last batch, the first eight inputs and their
    reconstructions are written to ``output/output{epoch}.png``; ``epoch`` is
    read from module scope.

    Args:
        dataloader: yields ``(images, labels)`` batches; labels are ignored.
            Must expose ``.dataset`` with a length.

    Returns:
        Sum of the per-batch losses divided by ``len(dataloader.dataset)``.
    """
    encoder.eval()
    decoder.eval()
    running_loss = 0.0
    num_rows = 8
    # len(dataloader) counts the trailing partial batch too, so both the
    # progress bar and the "last batch" check line up with the iteration.
    num_batches = len(dataloader)
    with torch.no_grad():
        for i, (data, _) in tqdm(enumerate(dataloader), total=num_batches):
            data = data.to(device)
            if not convolutional:
                # Same flattening as train_VAE for the fully connected models.
                data = data.view(data.size(0), -1)

            z, mu, log_var = encoder(data)
            reconstruction = decoder(z)

            bce_loss = criterion(reconstruction, data)
            loss = loss_function(bce_loss, mu, log_var)
            running_loss += loss.item()

            if i == num_batches - 1:
                # view(-1, ...) copes with a final batch smaller than batch_size.
                both = torch.cat((data.view(-1, 1, 28, 28)[:num_rows],
                                  reconstruction.view(-1, 1, 28, 28)[:num_rows]))
                save_image(both.cpu(), f'output/output{epoch}.png', nrow=num_rows)

    val_loss = running_loss / len(dataloader.dataset)
    return val_loss

# Training Loop
train_loss = []
val_loss = []
# Fixed noise so the preview grids of successive epochs are comparable.
noise = create_noise(sample_size, features)

print("On epoch: ")
for epoch in range(epochs):
    print(f"{epoch+1} ")
    train_epoch_loss = train_VAE(train_loader)
    # val_epoch_loss = validate_VAE(val_loader)

    # Generate the preview grid in eval mode and without autograd: in train
    # mode the BatchNorm layers would fold the statistics of these noise
    # samples into their running averages, and a graph would be built for
    # tensors that are never back-propagated.
    decoder_was_training = decoder.training
    decoder.eval()
    with torch.no_grad():
        generated_img = decoder(noise).cpu()
    decoder.train(decoder_was_training)

    generated_img = generated_img.view(-1, 1, 28, 28)
    generated_img = make_grid(generated_img)
    if convolutional:
        save_generator_image(generated_img, f'output_conv/gen_img{epoch}.png')
    else:
        save_generator_image(generated_img, f'output/gen_img{epoch}.png')
    images.append(generated_img)

    train_loss.append(train_epoch_loss)
    # val_loss.append(val_epoch_loss)

# save models: weights-only state dicts first, then the fully pickled modules
if convolutional:
    checkpoint_targets = (
        (encoder.state_dict(), 'models/encoder_conv.pth'),
        (decoder.state_dict(), 'models/decoder_conv.pth'),
        (encoder, 'models/encoder_model_conv.pth'),
        (decoder, 'models/decoder_model_conv.pth'),
    )
    gif_path = 'output_conv/generator_images.gif'
else:
    checkpoint_targets = (
        (encoder.state_dict(), 'models/encoder.pth'),
        (decoder.state_dict(), 'models/decoder.pth'),
        (encoder, 'models/encoder_model.pth'),
        (decoder, 'models/decoder_model.pth'),
    )
    gif_path = 'output/generator_images.gif'

for payload, target in checkpoint_targets:
    torch.save(payload, target)

# save the generated images as GIF file, one frame per stored grid
imgs = [np.array(to_pil_image(grid)) for grid in images]
imageio.mimsave(gif_path, imgs)

On epoch: 
1 
2 
3 
4 
5 
6 
7 
8 
9 
10 
11 
12 
13 
14 
15 
16 
17 
18 
19 
20 
21 
22 
23 
24 
25 
26 
27 
28 
29 
30 
31 
32 
33 
34 
35 
36 
37 
38 
39 
40 
41 
42 
43 
44 
45 
46 
47 
48 
49 
50 
51 
52 
53 
54 
55 
56 
57 
58 
59 
60 
61 
62 
63 
64 
65 
66 
67 
68 
69 
70 
71 
72 
73 
74 
75 
76 
77 
78 
79 
80 
81 
82 
83 
84 
85 
86 
87 
88 
89 
90 
91 
92 
93 
94 
95 
96 
97 
98 
99 
100 
101 
102 
103 
104 
105 
106 
107 
108 
109 
110 
111 
112 
113 
114 
115 
116 
117 
118 
119 
120 
121 
122 
123 
124 
125 
126 
127 
128 
129 
130 
131 
132 
133 
134 
135 
136 
137 
138 
139 
140 
141 
142 
143 
144 
145 
146 
147 
148 
149 
150 


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import numpy as np
import imageio

from tqdm import tqdm
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.utils import save_image, make_grid

# learning parameters
batch_size = 512  # mini-batch size shared by the train and validation loaders
epochs = 100  # number of passes over the training set
features = 16  # width of the noise tensor created before the training loop
sample_size = 64  # number of noise vectors created before the training loop
nz = 20
lr = 0.0002
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# transforms: images become float tensors; to_pil_image converts tensors back
transform = transforms.Compose([transforms.ToTensor()])
to_pil_image = transforms.ToPILImage()

# datasets (downloaded into ../../data when not already present)
train_data = datasets.MNIST(root='../../data', train=True, download=True, transform=transform)
val_data = datasets.MNIST(root='../../data', train=False, download=True, transform=transform)

# loaders: only the training split is shuffled
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

# class Encoder(nn.Module):
#   def __init__(self, features):
#     super(Encoder, self).__init__()
#     self.features = features
#     self.encoder = nn.Sequential(
#         nn.Linear(784,512),
#         nn.LeakyReLU(0.2),
#         nn.Linear(512,self.features*2),        
#     )

#   def reparameterize(self, mu, log_var):
#     # mu: mean from encoder's latent space
#     # log_var: log variance from encoder's latent space

#     std = torch.exp(0.5*log_var)
#     eps = torch.randn_like(std)
#     sample = mu + (eps * std)
#     return sample

#   def forward(self, x):
#     x = self.encoder(x).view(-1, 2, self.features)

#     mu = x[:, 0, :]
#     log_var = x[:, 1, :]

#     z = self.reparameterize(mu, log_var)

#     return z, mu, log_var
# convolutional = 0
# class Decoder(nn.Module):
#   def __init__(self, features):
#     super(Decoder, self).__init__()
#     self.features = features
#     self.decoder1 = nn.Sequential(
#         nn.Linear(self.features, 512), #half the in features?
#         nn.LeakyReLU(0.2),
#     )
#     self.decoder2 = nn.Sequential(
#         nn.Linear(512,784),
#         nn.Sigmoid(),
#     )

#   def forward(self, z):
#     x = self.decoder1(z)
#     reconstruction = self.decoder2(x)
#     return reconstruction

ngf = 64  # base channel width of the decoder stack
ndf = 64  # base channel width of the encoder stack
nc = 1  # number of image channels


class VAE(nn.Module):
    """Single-module convolutional VAE (encoder, latent heads and decoder).

    The spatial sizes noted in the layer comments hold for 28x28 inputs; the
    ``view(-1, 1024)`` in ``encode`` relies on the conv stack collapsing each
    image to a 1024x1x1 map.

    Args:
        nz: size of the latent vector.
    """

    def __init__(self, nz):
        super().__init__()
        self.nz = nz

        self.encoder = nn.Sequential(
            # (nc) x 28 x 28 -> (ndf) x 14 x 14
            nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*2) x 7 x 7
            nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*4) x 4 x 4
            nn.Conv2d(ndf * 2, ndf * 4, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> 1024 x 1 x 1
            nn.Conv2d(ndf * 4, 1024, kernel_size=4, stride=1, padding=0, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.decoder = nn.Sequential(
            # 1024 x 1 x 1 -> (ngf*8) x 4 x 4
            nn.ConvTranspose2d(1024, ngf * 8, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(inplace=True),
            # -> (ngf*4) x 7 x 7
            nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(inplace=True),
            # -> (ngf*2) x 14 x 14
            nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(inplace=True),
            # -> (nc) x 28 x 28, squashed into [0, 1]
            nn.ConvTranspose2d(ngf * 2, nc, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Sigmoid(),
        )

        # Encoder head: shared hidden layer, then mean and log-variance.
        self.fc1 = nn.Linear(1024, 512)
        self.fc21 = nn.Linear(512, nz)
        self.fc22 = nn.Linear(512, nz)

        # Decoder stem: lifts a latent vector back to 1024 values.
        self.fc3 = nn.Linear(nz, 512)
        self.fc4 = nn.Linear(512, 1024)

        self.lrelu = nn.LeakyReLU()
        self.relu = nn.ReLU()

    def encode(self, x):
        """Return ``(mu, log_var)`` for the image batch ``x``."""
        flat = self.encoder(x).view(-1, 1024)
        hidden = self.fc1(flat)
        return self.fc21(hidden), self.fc22(hidden)

    def decode(self, z):
        """Map latent batch ``z`` to a batch of reconstructed images."""
        hidden = self.relu(self.fc3(z))
        feature_map = self.fc4(hidden).view(-1, 1024, 1, 1)
        return self.decoder(feature_map)

    def reparameterize(self, mu, log_var):
        """Return ``mu + eps * std`` with ``eps`` drawn from a standard normal.

        Args:
            mu: mean from the encoder's latent space.
            log_var: log variance from the encoder's latent space.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + (eps * std)

    def forward(self, x):
        """Return ``(reconstruction, mu, log_var)`` for the image batch ``x``."""
        mu, logvar = self.encode(x)
        latent = self.reparameterize(mu, logvar)
        return self.decode(latent), mu, logvar


# Flag marking that the convolutional configuration is active in this cell.
convolutional = 1
## Convolutional Configuration (code: https://github.com/coolvision/vae_conv/blob/master/vae_conv_model_mnist.py)
# class Encoder(nn.Module):
#   def __init__(self, features):
#     super(Encoder, self).__init__()
#     self.features = features
#     self.encoder = nn.Sequential(
#         nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
#         nn.LeakyReLU(0.2),
#         nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
#         nn.BatchNorm2d(128),
#         nn.LeakyReLU(0.2),
#         nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
#         nn.BatchNorm2d(256),
#         nn.LeakyReLU(0.2),
#         nn.Conv2d(256, 1024, kernel_size=4, stride=1, padding=1),
#         nn.LeakyReLU(0.2),
#     )
#     self.fc1 = nn.Linear(1024, 512)
#     self.fc21 = nn.Linear(512, 32)
#     self.fc22 = nn.Linear(512, 32)

#   def reparameterize(self, mu, log_var):
#     # mu: mean from encoder's latent space
#     # log_var: log variance from encoder's latent space

#     std = torch.exp(0.5*log_var)
#     eps = torch.randn_like(std)
#     sample = mu + (eps * std)
#     return sample

#   def forward(self, x):
#     print(x.size())
#     conv = self.encoder(x)
#     print(conv.size())
#     h1 = self.fc1(conv.view(-1,1024))
#     print(h1.size())
#     mu = self.fc21(h1)
#     log_var = self.fc22(h1)

#     z = self.reparameterize(mu, log_var)
#     print("z: ")
#     print(z.size())

#     return z, mu, log_var
# convolutional = 1
# class Decoder(nn.Module):
#   def __init__(self, features):
#     super(Decoder, self).__init__()
#     self.features = features
#     self.decoder = nn.Sequential(
#         nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=1, padding=0),
#         nn.BatchNorm2d(512), 
#         nn.LeakyReLU(0.2),
#         nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1), 
#         nn.BatchNorm2d(512),
#         nn.LeakyReLU(0.2),
#         nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
#         nn.BatchNorm2d(128), 
#         nn.LeakyReLU(0.2),
#         nn.ConvTranspose2d(128, 1, kernel_size=4, stride=2, padding=1), 
#         nn.Sigmoid(),
#     )
#     self.fc = nn.Linear(32,1024)

#   def forward(self, z):
#     print(z.size())
#     x = self.fc(z)
#     print(x.size())
#     x = x.view(-1, 1024, 1, 1)
#     print(x.size())
#     reconstruction = self.decoder(x)
#     print(reconstruction.size())
#     return reconstruction

#######################################################
# encoder = Encoder(features).to(device)
# decoder = Decoder(features).to(device)
# optim_e = optim.Adam(encoder.parameters(), lr=lr)
# optim_d = optim.Adam(decoder.parameters(), lr=lr)

# NOTE(review): the latent size 32 is hard-coded here, while the
# hyperparameters at the top of this cell set nz = 20 and features = 16 --
# confirm which value is intended.
vae = VAE(32).to(device)
# Build Adam through the fully qualified module path: this assignment rebinds
# the name `optim` (imported above as the torch.optim module) to the optimizer
# instance, so the statement must not depend on `optim` still being the module.
# The name is kept because train_VAE calls optim.zero_grad() / optim.step().
# NOTE(review): lr is 0.002 here, not the `lr` hyperparameter (0.0002) -- confirm.
optim = torch.optim.Adam(vae.parameters(), lr=0.002)

# Reconstruction loss summed over all pixels of the batch.
criterion = nn.BCELoss(reduction='sum')

images = []  # to store images generated by the generator

def loss_function(bce_loss, mu, log_var):
    """Return the VAE objective: reconstruction loss plus KL divergence.

    The KL term is ``-0.5 * sum(1 + log_var - mu^2 - exp(log_var))``, summed
    over every element of ``mu`` / ``log_var``.

    Args:
        bce_loss: reconstruction loss already computed by the caller.
        mu: latent means.
        log_var: latent log-variances, same shape as ``mu``.
    """
    per_element = 1 + log_var - mu.pow(2) - log_var.exp()
    kl_divergence = -0.5 * per_element.sum()
    return kl_divergence + bce_loss

# function to create the noise vector
def create_noise(sample_size, nz):
    """Draw a ``(sample_size, nz)`` standard-normal tensor and move it to ``device``."""
    latent = torch.randn(sample_size, nz)
    return latent.to(device)

def save_generator_image(image, path):
  """Write ``image`` to ``path`` by delegating to torchvision's ``save_image``."""
  save_image(image, path)

def train_VAE(dataloader):
    """Run one training epoch of ``vae`` over ``dataloader`` and return the mean loss.

    Uses the module-level ``vae``, its optimizer ``optim``, ``criterion`` and
    ``device``.

    Args:
        dataloader: yields ``(images, labels)`` batches; labels are ignored.
            Must expose ``.dataset`` with a length.

    Returns:
        Sum of the per-batch losses divided by ``len(dataloader.dataset)``.
    """
    vae.train()
    running_loss = 0.0
    # Iterate the loader handed in by the caller so that the normalisation
    # below matches the data that was actually seen.
    for data, _ in dataloader:
        data = data.to(device)

        optim.zero_grad()

        reconstruction, mu, log_var = vae(data)

        bce_loss = criterion(reconstruction, data)
        loss = loss_function(bce_loss, mu, log_var)
        running_loss += loss.item()
        loss.backward()
        optim.step()

    train_loss = running_loss / len(dataloader.dataset)
    return train_loss

# Validate reconstruction performance
def validate_VAE(dataloader):
    """Evaluate ``vae`` on ``dataloader`` and return the mean loss.

    Runs the module-level ``vae`` in eval mode without gradients. On the last
    batch, the first eight inputs and their reconstructions are written to
    ``output/output{epoch}.png``; ``epoch`` is read from module scope.

    Args:
        dataloader: yields ``(images, labels)`` batches; labels are ignored.
            Must expose ``.dataset`` with a length.

    Returns:
        Sum of the per-batch losses divided by ``len(dataloader.dataset)``.
    """
    # The model evaluated below is `vae`, so that is the module to switch to
    # eval mode.
    vae.eval()
    running_loss = 0.0
    num_rows = 8
    # len(dataloader) counts the trailing partial batch too, so both the
    # progress bar and the "last batch" check line up with the iteration.
    num_batches = len(dataloader)
    with torch.no_grad():
        for i, (data, _) in tqdm(enumerate(dataloader), total=num_batches):
            data = data.to(device)

            reconstruction, mu, log_var = vae(data)

            bce_loss = criterion(reconstruction, data)
            loss = loss_function(bce_loss, mu, log_var)
            running_loss += loss.item()

            if i == num_batches - 1:
                # view(-1, ...) copes with a final batch smaller than batch_size.
                both = torch.cat((data.view(-1, 1, 28, 28)[:num_rows],
                                  reconstruction.view(-1, 1, 28, 28)[:num_rows]))
                save_image(both.cpu(), f'output/output{epoch}.png', nrow=num_rows)

    val_loss = running_loss / len(dataloader.dataset)
    return val_loss

# Training Loop
train_loss = []
val_loss = []
# Size the fixed noise from the model's own latent width (vae.nz) so that it
# can be passed to vae.decode; the `features` hyperparameter does not match
# the latent size the model was built with.
noise = create_noise(sample_size, vae.nz)

print("On epoch: ")
for epoch in range(epochs):
    print(f"{epoch+1} ")
    train_epoch_loss = train_VAE(train_loader)
    # val_epoch_loss = validate_VAE(val_loader)

    # Preview-grid generation is disabled in this cell. To re-enable it,
    # decode the fixed noise with the combined model (vae.decode(noise)),
    # build a grid with make_grid, save it with save_generator_image and
    # append it to `images`.

    train_loss.append(train_epoch_loss)
    # val_loss.append(val_epoch_loss)

# save models
# if convolutional:
#   torch.save(encoder.state_dict(), 'models/encoder_conv.pth')
#   torch.save(decoder.state_dict(), 'models/decoder_conv.pth')
#   torch.save(encoder, 'models/encoder_model_conv.pth')
#   torch.save(decoder, 'models/decoder_model_conv.pth')
# else:
#   torch.save(encoder.state_dict(), 'models/encoder.pth')
#   torch.save(decoder.state_dict(), 'models/decoder.pth')
#   torch.save(encoder, 'models/encoder_model.pth')
#   torch.save(decoder, 'models/decoder_model.pth')

# # save the generated images as GIF file
# imgs = [np.array(to_pil_image(img)) for img in images]
# if convolutional:
#   imageio.mimsave('output_conv/generator_images.gif', imgs)
# else:
#   imageio.mimsave('output/generator_images.gif', imgs)

On epoch: 
1 
torch.Size([512, 1, 28, 28])
torch.Size([512, 1, 28, 28])
torch.Size([512, 1, 28, 28])
torch.Size([512, 1, 28, 28])
torch.Size([512, 1, 28, 28])
torch.Size([512, 1, 28, 28])
torch.Size([512, 1, 28, 28])
torch.Size([512, 1, 28, 28])
torch.Size([512, 1, 28, 28])
torch.Size([512, 1, 28, 28])
torch.Size([512, 1, 28, 28])
torch.Size([512, 1, 28, 28])


KeyboardInterrupt: ignored