In [1]:
!pip install torch==1.11.0 torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/, https://download.pytorch.org/whl/cu113


In [2]:
import random
import os
from typing import Sequence, Tuple

import numpy as np
import torch as t
import torch.utils.data
import tqdm
import math
import torch.nn.functional as F
from torch import nn
from torch.nn import DataParallel
from torchvision import transforms, datasets
from torch.autograd import Variable
from torch import optim
from torch.distributions import Normal, Distribution, kl_divergence
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime


In [3]:
class IterableWrapper(IterableDataset):
    """
    Turns a Dataset into an IterableDataset by endlessly yielding random samples from the dataset.
    """

    def __init__(self, delegate: Dataset):
        self.delegate = delegate

    def __iter__(self):
        l = len(self.delegate)
        while True:
            for idx in random.sample(range(l), l):
                yield self.delegate[idx]

In [4]:
def iwae(x: t.Tensor, p_x_given_z: Distribution, q_z_given_x: Distribution, p_z: Distribution, z: t.Tensor):
    """
        log(p(x)) >= logsumexp_{i=1}^N[ log(p(x|z_i)) + log(p(z_i)) - log(q(z_i|x))] - log(N)
        x: bchw
        q_z_given_x: bz
        z: bnz
        p_x_given_z: (bn)chw
    """
    b, c, h, w = x.shape
    b, n, zs = z.shape

    x = (x.unsqueeze(1)
         .expand((-1, n, -1, -1, -1))
         .reshape(b * n, c, h, w)
         )  # (bn)chw
    logpx_given_z = p_x_given_z.log_prob(x).sum(dim=(1, 2, 3)).reshape((b, n))
    logpz = p_z.log_prob(z).sum(dim=2)
    logqz_given_x = q_z_given_x.log_prob(z.permute((1, 0, 2))).sum(dim=2).permute((1, 0))
    logpx = (t.logsumexp(logpx_given_z + logpz - logqz_given_x, dim=1) - t.log(t.scalar_tensor(z.shape[1])))

    return -logpx, None, None

def elbo(x: t.Tensor, p_x_given_z: Distribution, q_z_given_x: Distribution, p_z: Distribution, z: t.Tensor):
  """
      log p(x) >= E_q(z|x) [ log p(x|z) p(z) / q(z|x) ]
      Reconstruction + KL divergence losses summed over all elements and batch
      x: bchw
      q_z_given_x: bz
      z: bnz
      p_x_given_z: (bn)chw
  """

  b, c, h, w = x.shape
  b, n, zs = z.shape

  x = (x.unsqueeze(1)
      .expand((-1, n, -1, -1, -1))
      .reshape(-1, c, h, w)
      )  # (bn)chw

  logpx_given_z = p_x_given_z.log_prob(x).sum(dim=(1, 2, 3)).reshape((b, n)).mean(dim=1)
  kld = kl_divergence(q_z_given_x, p_z).sum(dim=1)

  reconstruction_loss = -logpx_given_z
  kl_loss = kld

  loss = reconstruction_loss + kl_loss
  return loss, reconstruction_loss, kl_loss

In [5]:
revision = os.environ.get("REVISION") or "%s" % datetime.now()
message = os.environ.get('MESSAGE')
tensorboard_dir = 'tensorboard'
flush_secs = 10

def get_writers(name):
    train_writer = SummaryWriter(tensorboard_dir + 'train', flush_secs=flush_secs)
    test_writer = SummaryWriter(tensorboard_dir + 'test', flush_secs=flush_secs)
    return train_writer, test_writer

In [6]:
class Model(t.nn.Module):

  def train_batch(self) -> float:
    raise NotImplemented()

  def eval_batch(self) -> float:
    raise NotImplemented()

  def save(self, fn):
    t.save({
      'batch_idx': self.batch_idx,
      'model_state_dict': self.state_dict(),
      'optimizer_state_dict': self.optimizer.state_dict(),
    }, fn)

  def load(self, fn):
    checkpoint = t.load(fn, map_location=t.device(self.device))
    self.batch_idx = checkpoint["batch_idx"]
    self.load_state_dict(checkpoint["model_state_dict"])
    self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

In [7]:
def train(model: Model, n_updates=int(1e6), eval_interval=1000):
  best = float("inf")
  for i in tqdm.tqdm(range(n_updates)):
    model.train_batch()
    if (i + 1) % eval_interval == 0:
      loss = model.eval_batch()
      model.save("latest")
      if loss < best:
        best = loss
        model.save("best")

In [8]:
class DiscretizedMixtureLogitsDistribution(Distribution):
    def __init__(self, nr_mix, logits):
        super().__init__()
        self.logits = logits
        self.nr_mix = nr_mix
        self._batch_shape = logits.shape

    def log_prob(self, value):
        return - discretized_mix_logistic_loss(value * 2 - 1, self.logits).unsqueeze(1)  # add channel dim for compatibility with loss functions expecting bchw

    def sample(self):
        return (sample_from_discretized_mix_logistic(self.logits, self.nr_mix) + 1) / 2

    @property
    def mean(self):
        """
        Returns the mean of the distribution.
        """
        return t.stack([self.sample() for _ in range(100)]).mean(dim=0)

class DiscretizedMixtureLogits():

    def __init__(self, nr_mix, **kwargs):
        self.nr_mix = nr_mix

    def __call__(self, logits):
        return DiscretizedMixtureLogitsDistribution(self.nr_mix, logits)

def log_sum_exp(x):
    """ numerically stable log_sum_exp implementation that prevents overflow """
    # TF ordering
    axis = len(x.size()) - 1
    m, _ = torch.max(x, dim=axis)
    m2, _ = torch.max(x, dim=axis, keepdim=True)
    return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))

def log_prob_from_logits(x):
    """ numerically stable log_softmax implementation that prevents overflow """
    # TF ordering
    axis = len(x.size()) - 1
    m, _ = torch.max(x, dim=axis, keepdim=True)
    return x - m - torch.log(torch.sum(torch.exp(x - m), dim=axis, keepdim=True))


def discretized_mix_logistic_loss(x, l):
    """ log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval """
    # Pytorch ordering
    x = x.permute(0, 2, 3, 1)
    l = l.permute(0, 2, 3, 1)
    xs = [int(y) for y in x.size()]
    ls = [int(y) for y in l.size()]

    # here and below: unpacking the params of the mixture of logistics
    nr_mix = int(ls[-1] / 10)
    logit_probs = l[:, :, :, :nr_mix]
    l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 3])  # 3 for mean, scale, coef
    means = l[:, :, :, :, :nr_mix]
    # log_scales = torch.max(l[:, :, :, :, nr_mix:2 * nr_mix], -7.)
    log_scales = torch.clamp(l[:, :, :, :, nr_mix:2 * nr_mix], min=-7.)

    coeffs = torch.tanh(l[:, :, :, :, 2 * nr_mix:3 * nr_mix])
    # here and below: getting the means and adjusting them based on preceding
    # sub-pixels
    x = x.contiguous()
    x = x.unsqueeze(-1) + torch.zeros(xs + [nr_mix], device=x.device)
    m2 = (means[:, :, :, 1, :] + coeffs[:, :, :, 0, :]
          * x[:, :, :, 0, :]).view(xs[0], xs[1], xs[2], 1, nr_mix)

    m3 = (means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] +
          coeffs[:, :, :, 2, :] * x[:, :, :, 1, :]).view(xs[0], xs[1], xs[2], 1, nr_mix)

    means = torch.cat((means[:, :, :, 0, :].unsqueeze(3), m2, m3), dim=3)
    centered_x = x - means
    inv_stdv = torch.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 1. / 255.)
    cdf_plus = torch.sigmoid(plus_in)
    min_in = inv_stdv * (centered_x - 1. / 255.)
    cdf_min = torch.sigmoid(min_in)
    # log probability for edge case of 0 (before scaling)
    log_cdf_plus = plus_in - F.softplus(plus_in)
    # log probability for edge case of 255 (before scaling)
    log_one_minus_cdf_min = -F.softplus(min_in)
    cdf_delta = cdf_plus - cdf_min  # probability for all other cases
    mid_in = inv_stdv * centered_x
    # log probability in the center of the bin, to be used in extreme cases
    # (not actually used in our code)
    log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)

    # now select the right output: left edge case, right edge case, normal
    # case, extremely low prob case (doesn't actually happen for us)

    # this is what we are really doing, but using the robust version below for extreme cases in other applications and to avoid NaN issue with tf.select()
    # log_probs = tf.select(x < -0.999, log_cdf_plus, tf.select(x > 0.999, log_one_minus_cdf_min, tf.log(cdf_delta)))

    # robust version, that still works if probabilities are below 1e-5 (which never happens in our code)
    # tensorflow backpropagates through tf.select() by multiplying with zero instead of selecting: this requires use to use some ugly tricks to avoid potential NaNs
    # the 1e-12 in tf.maximum(cdf_delta, 1e-12) is never actually used as output, it's purely there to get around the tf.select() gradient issue
    # if the probability on a sub-pixel is below 1e-5, we use an approximation
    # based on the assumption that the log-density is constant in the bin of
    # the observed sub-pixel value

    inner_inner_cond = (cdf_delta > 1e-5).float()
    inner_inner_out = inner_inner_cond * torch.log(torch.clamp(cdf_delta, min=1e-12)) + (1. - inner_inner_cond) * (
            log_pdf_mid - np.log(127.5))
    inner_cond = (x > 0.999).float()
    inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out
    cond = (x < -0.999).float()
    log_probs = cond * log_cdf_plus + (1. - cond) * inner_out
    log_probs = torch.sum(log_probs, dim=3) + log_prob_from_logits(logit_probs)

    return - log_sum_exp(log_probs)

def to_one_hot(tensor, n, fill_with=1.):
    # we perform one hot encore with respect to the last axis
    one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_()
    if tensor.is_cuda: one_hot = one_hot.cuda()
    one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with)
    return Variable(one_hot)


def sample_from_discretized_mix_logistic(l, nr_mix):
    # Pytorch ordering
    l = l.permute(0, 2, 3, 1)
    ls = [int(y) for y in l.size()]
    xs = ls[:-1] + [3]

    # unpack parameters
    logit_probs = l[:, :, :, :nr_mix]
    l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 3])
    # sample mixture indicator from softmax
    temp = torch.FloatTensor(logit_probs.size())
    if l.is_cuda: temp = temp.cuda()
    temp.uniform_(1e-5, 1. - 1e-5)
    temp = logit_probs.data - torch.log(- torch.log(temp))
    _, argmax = temp.max(dim=3)

    one_hot = to_one_hot(argmax, nr_mix)
    sel = one_hot.view(xs[:-1] + [1, nr_mix])
    # select logistic parameters
    means = torch.sum(l[:, :, :, :, :nr_mix] * sel, dim=4)
    log_scales = torch.clamp(torch.sum(
        l[:, :, :, :, nr_mix:2 * nr_mix] * sel, dim=4), min=-7.)
    coeffs = torch.sum(torch.tanh(
        l[:, :, :, :, 2 * nr_mix:3 * nr_mix]) * sel, dim=4)
    # sample from logistic & clip to interval
    # we don't actually round to the nearest 8bit value when sampling
    u = torch.FloatTensor(means.size())
    if l.is_cuda: u = u.cuda()
    u.uniform_(1e-5, 1. - 1e-5)
    u = Variable(u)
    x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))
    x0 = torch.clamp(torch.clamp(x[:, :, :, 0], min=-1.), max=1.)
    x1 = torch.clamp(torch.clamp(
        x[:, :, :, 1] + coeffs[:, :, :, 0] * x0, min=-1.), max=1.)
    x2 = torch.clamp(torch.clamp(
        x[:, :, :, 2] + coeffs[:, :, :, 1] * x0 + coeffs[:, :, :, 2] * x1, min=-1.), max=1.)

    out = torch.cat([x0.view(xs[:-1] + [1]), x1.view(xs[:-1] + [1]), x2.view(xs[:-1] + [1])], dim=3)
    # put back in Pytorch ordering
    out = out.permute(0, 3, 1, 2)
    return out

In [9]:
n_mixtures = 1

def state_to_dist(state):
    return DiscretizedMixtureLogitsDistribution(n_mixtures, state[:, :n_mixtures * 10, :, :])

In [10]:
class VAE(Model):
    
  def __init__(self,
                 h: int,
                 w: int,
                 n_channels: int,
                 z_size: int,
                 encoder: t.nn.Module,
                 decoder_linear: t.nn.Module, 
                 decoder: t.nn.Module,
                 train_data: Dataset,
                 val_data: Dataset,
                 test_data: Dataset,
                 states_to_dist,
                 batch_size: int,
                 p_update: float,
                 min_steps: int,
                 max_steps: int, 
                 encoder_hid
                 ):
    super(Model, self).__init__()
    self.h = h # height of the image
    self.w = w # width of the image
    self.n_channels = n_channels # number of channels of the image
    self.state_to_dist = states_to_dist # function that turns a set of state to a distribution
    self.z_size = z_size # dimensionality of the latent space
    self.device = "cuda" if t.cuda.is_available() else "cpu" # check if we have a gpu

    self.encoder = encoder # define the encoder
    self.decoder_linear = decoder_linear # define the linear decoder
    self.decoder = decoder # define the decoder
    self.unflatten = nn.Unflatten(-1, (encoder_hid * 2 ** 5, h // 16, w // 16))
    self.p_z = Normal(t.zeros(self.z_size, device=self.device), t.ones(self.z_size, device=self.device)) # defines a 0 mean prior distribution for the latent space

    self.test_set = test_data # appoint the test data
    self.train_loader = iter(DataLoader(IterableWrapper(train_data), batch_size=batch_size, pin_memory=True)) # initialize a data loader for the training data
    self.val_loader = iter(DataLoader(IterableWrapper(val_data), batch_size=batch_size, pin_memory=True)) # initialize a data loader for the validation data
    self.train_writer, self.test_writer = get_writers("vae") # initialize a writer for the tensorboard

    print(self) # report the model
    total = sum(p.numel() for p in self.parameters()) # calculate the total number of learnable parameters
    for n, p in self.named_parameters():
        print(n, p.shape, p.numel(), "%.1f" % (p.numel() / total * 100)) # report information about the layers of the encoder and the decoder
    print("Total: %d" % total) # print the total number of learnable parameters

    self.to(self.device) # move the pytorch tensor to the gpu if possible
    self.optimizer = optim.Adam(self.parameters(), lr=1e-4) # initialize the ADAM optimizer
    self.batch_idx = 0 # initalize the batch index to 0

  def train_batch(self):
    self.train(True) # set the training mode to True

    self.optimizer.zero_grad() # remove prior gradients from the mmodel
    x, y = next(self.train_loader) # load a batch of training data
    loss, z, p_x_given_z, recon_loss, kl_loss, state = self.forward(x, 1, elbo) # forward the batch of training data throught the network
    loss.mean().backward() # gradient of the loss with respect to all the model parameters

    t.nn.utils.clip_grad_norm_(self.parameters(), 1.0, error_if_nonfinite=True) # clip the gradient to a range of [-1, 1] 

    self.optimizer.step() # use the clip gradient to perform a step of backpropagation on the parameters of the model

    if self.batch_idx % 100 == 0: 
        self.report(self.train_writer, state, loss, recon_loss, kl_loss) # report on the results every 100 steps

    self.batch_idx += 1 # increment the batch index
    return loss.mean().item() 
    
  def eval_batch(self):
    self.train(False) # set the training mode to False
    with t.no_grad():
        x, y = next(self.val_loader) # load a batch of validation data
        loss, z, p_x_given_z, recon_loss, kl_loss, state = self.forward(x, 1, iwae) # forward the batch of validation data throught the network
        self.report(self.test_writer, state, loss, recon_loss, kl_loss) # report on the results
    return loss.mean().item()

  def test(self, n_iw_samples):
      self.train(False) # set the training mode to False
      with t.no_grad():
          total_loss = 0.0 # initialize the total loss
          for x, y in tqdm.tqdm(self.test_set): # iterate over the whole test set 
              loss, z, p_x_given_z, recon_loss, kl_loss, states = self.forward(x, n_iw_samples, iwae) # forward a single sample of the testing data through the network 
              total_loss += loss.mean().item() # add the mean of the loss to the total loss

      print(total_loss / len(self.test_set)) # return the average loss of the test set

  def encode(self, x) -> Distribution:  # q(z|x)
      q = self.encoder(x) # run the encoder and retunrs a vector of size 2 * the latent space 
      loc = q[:, :self.z_size] # mean of the latent distribution
      logsigma = q[:, self.z_size:] # log variance of the latent ditribution
      return Normal(loc=loc, scale=t.exp(logsigma)) # return a normal distribution with the mean and variance received from the encoder

  def decode(self, z: t.Tensor) -> Tuple[Distribution, Sequence[t.Tensor]]:  # p(x|z)
      flat_features = self.decoder_linear(z)
      flat_features = t.squeeze(flat_features)
      unflattened = self.unflatten(flat_features)
      return self.decoder(unflattened) # run the decoder

  def forward(self, x, n_samples, loss_fn):
      x = x.to(self.device) # move the pytorch tensor to the gpu if possible

      q_z_given_x = self.encode(x) # run the encoder to receive the conditional latent distribution
      z = q_z_given_x.rsample((n_samples,)).permute((1, 0, 2)) # sample from the conditional latent distribution

      state = self.decode(z) # decode the sample
      print(state.shape)
      p_x_given_z = self.state_to_dist(state) # get the conditional probability distribution using the state

      loss, recon_loss, kl_loss = loss_fn(x, p_x_given_z, q_z_given_x, self.p_z, z) # calculate the loss using the two distributions

      return loss, z, p_x_given_z, recon_loss, kl_loss, state

  def report(self, writer: SummaryWriter, recon_state, loss, recon_loss, kl_loss):
      writer.add_scalar('loss', loss.mean().item(), self.batch_idx)
      writer.add_scalar('bpd', loss.mean().item() / (np.log(2) * self.n_channels * self.h * self.w), self.batch_idx)
      writer.add_scalar('pool_size', len(self.pool), self.batch_idx)

      if recon_loss is not None:
          writer.add_scalar('recon_loss', recon_loss.mean().item(), self.batch_idx)
      if kl_loss is not None:
          writer.add_scalar('kl_loss', kl_loss.mean().item(), self.batch_idx)

      with t.no_grad():
          # samples
          samples = self.p_z.sample((8,)).view(8, -1, 1, 1).expand(8, -1, self.h, self.w).to(self.device)
          states = self.decode(samples) # decode the samples into images
          writer.add_images("samples/samples", states, self.batch_idx)
  
          # Reconstructions
          writer.add_images("recons/samples", recon_state.detach(), self.batch_idx)

      writer.flush()

In [11]:
z_size = 256
vae_hid = 128
n_mixtures = 1
batch_size = 32
dmg_size = 16
p_update = 1.0
min_steps, max_steps = 64, 128

filter_size = 5
pad = filter_size // 2
encoder_hid = 32
h = w = 32
n_channels = 3

In [15]:
encoder = nn.Sequential(
    nn.Conv2d(n_channels, encoder_hid * 2 ** 0, filter_size, padding=pad), nn.ELU(),  # (bs, 32, h, w)
    nn.Conv2d(encoder_hid * 2 ** 0, encoder_hid * 2 ** 1, filter_size, padding=pad, stride=2), nn.ELU(),  # (bs, 64, h//2, w//2)
    nn.Conv2d(encoder_hid * 2 ** 1, encoder_hid * 2 ** 2, filter_size, padding=pad, stride=2), nn.ELU(),  # (bs, 128, h//4, w//4)
    nn.Conv2d(encoder_hid * 2 ** 2, encoder_hid * 2 ** 3, filter_size, padding=pad, stride=2), nn.ELU(),  # (bs, 256, h//8, w//8)
    nn.Conv2d(encoder_hid * 2 ** 3, encoder_hid * 2 ** 4, filter_size, padding=pad, stride=2), nn.ELU(),  # (bs, 512, h//16, w//16),
    nn.Flatten(),  # (bs, 512*h//16*w//16)
    nn.Linear(encoder_hid * (2 ** 4) * h // 16 * w // 16, 2 * z_size),
)

decoder_linear = nn.Sequential(
    nn.Linear(z_size, (encoder_hid * 2 ** 5) * 4), nn.ELU()
)

decoder = nn.Sequential(
    nn.ConvTranspose2d(encoder_hid * 2 ** 5, encoder_hid * 2 ** 4, filter_size, padding=pad, stride=2, output_padding=1), nn.ELU(),
    nn.ConvTranspose2d(encoder_hid * 2 ** 4, encoder_hid * 2 ** 3, filter_size, padding=pad, stride=2, output_padding=1), nn.ELU(),
    nn.ConvTranspose2d(encoder_hid * 2 ** 3, encoder_hid * 2 ** 2, filter_size, padding=pad, stride=2, output_padding=1), nn.ELU(),
    nn.ConvTranspose2d(encoder_hid * 2 ** 2, encoder_hid * 2 ** 1, filter_size, padding=pad, stride=2, output_padding=1), nn.ELU(),
    nn.ConvTranspose2d(encoder_hid * 2 ** 1, encoder_hid * 2 ** 0, filter_size, padding=pad, stride=2, output_padding=1), nn.ELU(),
    nn.ConvTranspose2d(encoder_hid * 2 ** 0, 10, filter_size, padding=pad)
)

In [13]:
import os
import torch
from skimage import io
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset

# Dataset can be found and downloaded at "https://www.kaggle.com/datasets/kvpratama/pokemon-images-dataset/download"

data_dir = os.path.join(os.getcwd(), "dataset")

def to_alpha(x):
  return torch.clip(x[3:4,...], 0.0, 1.0)

def to_rgb(x):
  # assume rgb premultiplied by alpha
  rgb, a = x[:3,...], to_alpha(x)
  return 1.0-a+rgb

class PokemonIMG(Dataset):

    def __init__(self):
        self.filenames = os.listdir(data_dir)
        self.h = self.w = 32
        self.transform = transforms.Compose([transforms.Resize((self.h, self.w)), transforms.ToTensor()])

    def __getitem__(self, index):
        img_name = os.path.join(data_dir,
                                self.filenames[index])
        image = self.transform(Image.fromarray(io.imread(img_name)))
        #train_set[0][0][:3,:,:] *= train_set[0][0][3:,:,:]
        image[:3,...] *= image[3:,...]
        return to_rgb(image), 0  # placeholder label

    def __len__(self):
        return len(self.filenames)


In [16]:
z_size = 256
n_mixtures = 1
batch_size = 32
p_update = 1.0
min_steps, max_steps = 64, 128

filter_size = 5
pad = filter_size // 2
encoder_hid = 32
h = w = 32
n_channels = 3


def state_to_dist(state):
    return DiscretizedMixtureLogitsDistribution(n_mixtures, state[:, :n_mixtures * 10, :, :])

dset = PokemonIMG()

num_samples = len(dset)
train_split = 0.7
val_split = 0.2
test_split = 0.1

num_train = math.floor(num_samples*train_split)
num_val = math.floor(num_samples*val_split)
num_test = math.floor(num_samples*test_split)
num_test = num_test + (num_samples - num_train - num_val - num_test)

train_set, val_set, test_set = t.utils.data.random_split(dset, [num_train, num_val, num_test])

vae = VAE(h, w, n_channels, z_size, encoder, decoder_linear, decoder, train_set, val_set, test_set, state_to_dist, batch_size, p_update, min_steps, max_steps, encoder_hid)
vae.eval_batch()
train(vae, n_updates=100_000, eval_interval=100)
vae.test(128)


VAE(
  (encoder): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ELU(alpha=1.0)
    (2): Conv2d(32, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (3): ELU(alpha=1.0)
    (4): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (5): ELU(alpha=1.0)
    (6): Conv2d(128, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (7): ELU(alpha=1.0)
    (8): Conv2d(256, 512, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (9): ELU(alpha=1.0)
    (10): Flatten(start_dim=1, end_dim=-1)
    (11): Linear(in_features=2048, out_features=512, bias=True)
  )
  (decoder_linear): Sequential(
    (0): Linear(in_features=256, out_features=4096, bias=True)
    (1): ELU(alpha=1.0)
  )
  (decoder): Sequential(
    (0): ConvTranspose2d(1024, 512, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
    (1): ELU(alpha=1.0)
    (2): ConvTranspose2d(512, 256, kernel_size=(5, 5), stride=(2, 2), padding=

  'with `validate_args=False` to turn off validation.')


RuntimeError: ignored