## Image Transformer

### Imports

In [23]:
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
import logging
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

import argparse
import sys
import time
import os
import logging
import yaml
import shutil
import numpy as np
import tensorboardX
import torch.optim as optim
import torchvision
from image_transformer import ImageTransformer
import matplotlib
import itertools
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchviz import make_dot
from tqdm import tqdm
import torch.nn as nn

### Defining the Model

ImageTransformer Class

In [24]:
NUM_PIXELS = 256

# Numerically stable implementations
def logsoftmax(x):
    m = torch.max(x, -1, keepdim=True).values
    return x - m - torch.log(torch.exp(x - m).sum(-1, keepdim=True))

def logsumexp(x):
    m = x.max(-1).values
    return m + torch.log(torch.exp(x - m[...,None]).sum(-1))

tb_logger = tensorboardX.SummaryWriter(log_dir=os.path.join('transformer_logs', args.doc))

class ImageTransformer(nn.Module):
    """ImageTransformer with DMOL or categorical distribution."""
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams

        #Decoder layers are stacked (nlayers = 12). ModuleList is a way to stack decoder blocks.
        self.layers = nn.ModuleList([DecoderLayer(hparams) for _ in range(hparams.nlayers)])

        #Dropout = 0.1. Dropout is a way of regularizing NNs by preventing coadaptation.
        self.input_dropout = nn.Dropout(p=hparams.dropout)

        '''if self.hparams.distr == "dmol": # Discretized mixture of logistic, for ordinal valued inputs
        assert self.hparams.channels == 3, "Only supports 3 channels for DML"
        size = (1, self.hparams.channels)
        self.embedding_conv = nn.Conv2d(1, self.hparams.hidden_size, kernel_size=size, stride=size)
        # 10 = 1 + 2c + c(c-1)/2; if only 1 channel, then 3 total
        depth = self.hparams.num_mixtures * 10
        self.output_dense = nn.Linear(self.hparams.hidden_size, depth, bias=False)'''

        #elif self.hparams.distr == "cat": #Categorical
        self.embeds = nn.Embedding(NUM_PIXELS * self.hparams.channels, self.hparams.hidden_size)
        self.output_dense = nn.Linear(self.hparams.hidden_size, NUM_PIXELS, bias=True)

        #I didn't know that Pixels had embeddings as well. 
        #nn.Linear is just a basic feedforward network. 
        
    def add_timing_signal(self, X, min_timescale=1.0, max_timescale=1.0e4):
        '''
        Yes, that weird Sin and Cos trick. 
        '''
        num_dims = len(X.shape) - 2 # 2 corresponds to batch and hidden_size dimensions
        num_timescales = self.hparams.hidden_size // (num_dims * 2)
        log_timescale_increment = np.log(max_timescale / min_timescale) / (num_timescales - 1)
        inv_timescales = min_timescale * torch.exp((torch.arange(num_timescales).float() * -log_timescale_increment))
        inv_timescales = inv_timescales.to(X.device)
        total_signal = torch.zeros_like(X) # Only for debugging purposes
        for dim in range(num_dims):
            length = X.shape[dim + 1] # add 1 to exclude batch dim
            position = torch.arange(length).float().to(X.device)
            scaled_time = position.view(-1, 1) * inv_timescales.view(1, -1)
            signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 1)
            prepad = dim * 2 * num_timescales
            postpad = self.hparams.hidden_size - (dim + 1) * 2 * num_timescales
            signal = F.pad(signal, (prepad, postpad))
            for _ in range(1 + dim):
                signal = signal.unsqueeze(0)
            for _ in range(num_dims - 1 - dim):
                signal = signal.unsqueeze(-2)
            X += signal
            total_signal += signal
        return X

    def shift_and_pad_(self, X):
        # Shift inputs over by 1 and pad
        shape = X.shape
        X = X.view(shape[0], shape[1] * shape[2], shape[3])
        X = X[:,:-1,:]
        X = F.pad(X, (0, 0, 1, 0)) # Pad second to last dimension
        X = X.view(shape)
        return X

    def forward(self, X, sampling=False):
        # Reshape inputs
        if sampling:
            curr_infer_length = X.shape[1]
            row_size = self.hparams.image_size * self.hparams.channels
            nrows = curr_infer_length // row_size + 1
            X = F.pad(X, (0, nrows * row_size - curr_infer_length))
            X = X.view(X.shape[0], -1, row_size)
        else:
            X = X.permute([0, 2, 3, 1]).contiguous()
            X = X.view(X.shape[0], X.shape[1], X.shape[2] * X.shape[3]) # Flatten channels into width

        '''# Inputs -> embeddings
        if self.hparams.distr == "dmol":
            # Create a "channel" dimension for the 1x3 convolution
            # (NOTE: can apply a 1x1 convolution and not reshape, this is for consistency)
            X = X.unsqueeze(1)
            X = F.relu(self.embedding_conv(X))
            X = X.permute([0, 2, 3, 1]) # move channels to the end'''
        
        #elif self.hparams.distr == "cat":
        # Convert to indexes, and use separate embeddings for different channels
        X = (X * (NUM_PIXELS - 1)).long()
        channel_addition = (torch.tensor([0, 1, 2]) * NUM_PIXELS).to(X.device).repeat(X.shape[2] // 3).view(1, 1, -1)
        X += channel_addition
        X = self.embeds(X) * (self.hparams.hidden_size ** 0.5)

        X = self.shift_and_pad_(X)
        X = self.add_timing_signal(X)
        shape = X.shape
        X = X.view(shape[0], -1, shape[3])

        X = self.input_dropout(X)
        for layer in self.layers:
            X = layer(X)
        X = self.layers[-1].preprocess_(X) # NOTE: this is identity (exists to replicate tensorflow code)
        X = self.output_dense(X).view(shape[:3] + (-1,))

        if not sampling and self.hparams.distr == "cat": # Unpack the channels
            X = X.view(X.shape[0], X.shape[1], X.shape[2] // self.hparams.channels, self.hparams.channels, X.shape[3])
            X = X.permute([0, 3, 1, 2, 4])

        return X

    #def split_to_dml_params(self, preds, targets=None, sampling=False):
    #DELETED since this function was only used for DMOL distribution

    # Modified from official PixCNN++ code
    #def dml_logp(self, logits, means, log_scales, targets):
    #DELETED since this function was only used for DMOL distribution

    # Assumes targets have been rescaled to [-1., 1.]
    def loss(self, preds, targets):
        '''if self.hparams.distr == "dmol":
            # Assumes 3 channels. Input: [batch_size, height, width, 10 * 10]
            logits, locs, log_scales = self.split_to_dml_params(preds, targets)
            targets = targets.permute([0, 2, 3, 1])
            log_probs = self.dml_logp(logits, locs, log_scales, targets)
            return -log_probs'''
        #elif self.hparams.distr == "cat":
        targets = (targets * (NUM_PIXELS - 1)).long()
        ce = F.cross_entropy(preds.permute(0, 4, 1, 2, 3), targets, reduction='none')
        return ce

    def accuracy(self, preds, targets):
        #for Categorical Distribution
        targets = (targets * (NUM_PIXELS - 1)).long()
        argmax_preds = torch.argmax(preds, dim=-1)
        acc = torch.eq(argmax_preds, targets).float().sum() / np.prod(argmax_preds.shape)
        return acc

    def sample_from_dmol(self, outputs):
        logits, locs, log_scales, coeffs = self.split_to_dml_params(outputs, sampling=True)
        gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) * (1. - 2 * 1e-5) + 1e-5))
        sel = torch.argmax(logits + gumbel_noise, -1, keepdim=True)
        one_hot = torch.zeros_like(logits).scatter_(-1, sel, 1).unsqueeze(-2)
        locs = (locs * one_hot).sum(-1)
        log_scales = (log_scales * one_hot).sum(-1)
        coeffs = (coeffs * one_hot).sum(-1)
        unif = torch.rand_like(log_scales) * (1. - 2 * 1e-5) + 1e-5
        logistic_noise = torch.log(unif) - torch.log1p(-unif)
        x = locs + torch.exp(log_scales) * logistic_noise
        # NOTE: sampling analogously to pixcnn++, which clamps first, unlike image transformer
        x0 = torch.clamp(x[..., 0], -1., 1.)
        x1 = torch.clamp(x[..., 1] + coeffs[..., 0] * x0, -1., 1.)
        x2 = torch.clamp(x[..., 2] + coeffs[..., 1] * x0 + coeffs[..., 2] * x1, -1., 1.)
        x = torch.stack([x0, x1, x2], -1)
        return x

    def sample_from_cat(self, logits, argmax=False):
        if argmax:
            sel = torch.argmax(logits, -1, keepdim=False).float() / 255.
        else:
            gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) * (1. - 2 * 1e-5) + 1e-5))
            sel = torch.argmax(logits + gumbel_noise, -1, keepdim=False).float() / 255.
        return sel

    def sample(self, n, device, argmax=False):
        total_len = (self.hparams.image_size ** 2)
        if self.hparams.distr == "cat":
            total_len *= self.hparams.channels
        samples = torch.zeros((n, 3)).to(device)
        for curr_infer_length in tqdm(range(total_len)):
            outputs = self.forward(samples, sampling=True)
            outputs = outputs.view(n, -1, outputs.shape[-1])[:,curr_infer_length:curr_infer_length+1,:]
            if self.hparams.distr == "dmol":
                x = self.sample_from_dmol(outputs).squeeze()
            elif self.hparams.distr == "cat":
                x = self.sample_from_cat(outputs, argmax=argmax)
            if curr_infer_length == 0:
                samples = x
            else:
                samples = torch.cat([samples, x], 1)
        samples = samples.view(n, self.hparams.image_size, self.hparams.image_size, self.hparams.channels)
        samples = samples.permute(0, 3, 1, 2)
        return samples

    def sample_from_preds(self, preds, argmax=False):
        if self.hparams.distr == "dmol":
            samples = self.sample_from_dmol(preds)
            samples = samples.permute(0, 3, 1, 2)
        elif self.hparams.distr == "cat":
            samples = self.sample_from_cat(preds, argmax=argmax)
        return samples

DecoderLayer

In [25]:
class DecoderLayer(nn.Module):
    """Implements a single layer of an unconditional ImageTransformer"""
    def __init__(self, hparams):
        super().__init__()
        self.attn = Attn(hparams)
        self.hparams = hparams
        self.dropout = nn.Dropout(p=hparams.dropout)
        self.layernorm_attn = nn.LayerNorm([self.hparams.hidden_size], eps=1e-6, elementwise_affine=True)
        self.layernorm_ffn = nn.LayerNorm([self.hparams.hidden_size], eps=1e-6, elementwise_affine=True)
        self.ffn = nn.Sequential(nn.Linear(self.hparams.hidden_size, self.hparams.filter_size, bias=True),
                                 nn.ReLU(),
                                 nn.Linear(self.hparams.filter_size, self.hparams.hidden_size, bias=True))

    def preprocess_(self, X):
        return X

    # Takes care of the "postprocessing" from tensorflow code with the layernorm and dropout
    def forward(self, X):
        X = self.preprocess_(X)
        y = self.attn(X)
        X = self.layernorm_attn(self.dropout(y) + X)
        y = self.ffn(self.preprocess_(X))
        X = self.layernorm_ffn(self.dropout(y) + X)
        return X

Attn

In [26]:
class Attn(nn.Module):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.kd = self.hparams.total_key_depth or self.hparams.hidden_size
        self.vd = self.hparams.total_value_depth or self.hparams.hidden_size
        self.q_dense = nn.Linear(self.hparams.hidden_size, self.kd, bias=False)
        self.k_dense = nn.Linear(self.hparams.hidden_size, self.kd, bias=False)
        self.v_dense = nn.Linear(self.hparams.hidden_size, self.vd, bias=False)
        self.output_dense = nn.Linear(self.vd, self.hparams.hidden_size, bias=False)
        assert self.kd % self.hparams.num_heads == 0
        assert self.vd % self.hparams.num_heads == 0

    def dot_product_attention(self, q, k, v, bias=None):
        logits = torch.einsum("...kd,...qd->...qk", k, q)
        if bias is not None:
            logits += bias
        weights = F.softmax(logits, dim=-1)
        return weights @ v

    def forward(self, X):
        q = self.q_dense(X)
        k = self.k_dense(X)
        v = self.v_dense(X)
        # Split to shape [batch_size, num_heads, len, depth / num_heads]
        q = q.view(q.shape[:-1] + (self.hparams.num_heads, self.kd // self.hparams.num_heads)).permute([0, 2, 1, 3])
        k = k.view(k.shape[:-1] + (self.hparams.num_heads, self.kd // self.hparams.num_heads)).permute([0, 2, 1, 3])
        v = v.view(v.shape[:-1] + (self.hparams.num_heads, self.vd // self.hparams.num_heads)).permute([0, 2, 1, 3])
        q *= (self.kd // self.hparams.num_heads) ** (-0.5)

        if self.hparams.attn_type == "global":
            bias = -1e9 * torch.triu(torch.ones(X.shape[1], X.shape[1]), 1).to(X.device)
            result = self.dot_product_attention(q, k, v, bias=bias)
        elif self.hparams.attn_type == "local_1d":
            len = X.shape[1]
            blen = self.hparams.block_length
            pad = (0, 0, 0, (-len) % self.hparams.block_length) # Append to multiple of block length
            q = F.pad(q, pad)
            k = F.pad(k, pad)
            v = F.pad(v, pad)

            bias = -1e9 * torch.triu(torch.ones(blen, blen), 1).to(X.device)
            first_output = self.dot_product_attention(
                q[:,:,:blen,:], k[:,:,:blen,:], v[:,:,:blen,:], bias=bias)

            if q.shape[2] > blen:
                q = q.view(q.shape[0], q.shape[1], -1, blen, q.shape[3])
                k = k.view(k.shape[0], k.shape[1], -1, blen, k.shape[3])
                v = v.view(v.shape[0], v.shape[1], -1, blen, v.shape[3])
                local_k = torch.cat([k[:,:,:-1], k[:,:,1:]], 3) # [batch, nheads, (nblocks - 1), blen * 2, depth]
                local_v = torch.cat([v[:,:,:-1], v[:,:,1:]], 3)
                tail_q = q[:,:,1:]
                bias = -1e9 * torch.triu(torch.ones(blen, 2 * blen), blen + 1).to(X.device)
                tail_output = self.dot_product_attention(tail_q, local_k, local_v, bias=bias)
                tail_output = tail_output.view(tail_output.shape[0], tail_output.shape[1], -1, tail_output.shape[4])
                result = torch.cat([first_output, tail_output], 2)
                result = result[:,:,:X.shape[1],:]
            else:
                result = first_output[:,:,:X.shape[1],:]

        result = result.permute([0, 2, 1, 3]).contiguous()
        result = result.view(result.shape[0:2] + (-1,))
        result = self.output_dense(result)
        return result

### Configuration

In [27]:
def dict2namespace(config):
    namespace = argparse.Namespace()
    for key, value in config.items():
        if isinstance(value, dict):
            new_value = dict2namespace(value)
        else:
            new_value = value
        setattr(namespace, key, new_value)
    return namespace

def parse_args_and_config():
    """
    :return args, config: namespace objects that stores information in args and config files.
    """
    parser = argparse.ArgumentParser(description=globals()['__doc__'])

    parser.add_argument('--config', type=str, default='transformer_tiny.yml', help='Path to the config file')
    parser.add_argument('--doc', type=str, default='0', help='A string for documentation purpose')
    parser.add_argument('--verbose', type=str, default='info', help='Verbose level: info | debug | warning | critical')
    parser.add_argument('--sample', action='store_true', help='Sample at train time')

    args, unknown = parser.parse_known_args() #special modification for Jupyter notebooks. 
    #args = parser.parse_args()

    args.log = os.path.join('transformer_logs', args.doc)
    # parse config file
    with open(os.path.join('configs', args.config), 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    new_config = dict2namespace({**config, **vars(args)})

    if os.path.exists(args.log):
        shutil.rmtree(args.log)

    os.makedirs(args.log)

    with open(os.path.join(args.log, 'config.yml'), 'w') as f:
        yaml.dump(new_config, f, default_flow_style=False)

    # setup logger
    level = getattr(logging, args.verbose.upper(), None)
    if not isinstance(level, int):
        raise ValueError('level {} not supported'.format(args.verbose))

    handler1 = logging.StreamHandler()
    handler2 = logging.FileHandler(os.path.join(args.log, 'stdout.txt'))
    formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s')
    handler1.setFormatter(formatter)
    handler2.setFormatter(formatter)
    logger = logging.getLogger()
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    logger.setLevel(level)

    # add device information to args
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    logging.info("Using device: {}".format(device))
    new_config.device = device

    # set random seed
    torch.manual_seed(new_config.seed)
    torch.cuda.manual_seed_all(new_config.seed)
    np.random.seed(new_config.seed)
    logging.info("Run name: {}".format(args.doc))

    return args, new_config

def get_lr(step, config):
    warmup_steps = config.optim.warmup
    lr_base = config.optim.lr * 0.002 # for Adam correction
    ret = 5000. * config.model.hidden_size ** (-0.5) * \
          np.min([(step + 1) * warmup_steps ** (-1.5), (step + 1) ** (-0.5)])
    return ret * lr_base

### Load Dataset

In [28]:
args, config = parse_args_and_config()

if config.model.distr == "dmol":
    # Scale size and rescale data to [-1, 1]
    transform = transforms.Compose([
        transforms.Resize(config.model.image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                            std=[0.5, 0.5, 0.5])
    ])
else:
    transform = transforms.Compose([
        transforms.Resize(config.model.image_size),
        transforms.ToTensor()
    ])

dataset = datasets.CIFAR10('datasets/transformer', transform=transform, download=True)
loader = DataLoader(dataset, batch_size=config.train.batch_size, shuffle=True, num_workers=4)
input_dim = config.model.image_size ** 2 * config.model.channels
model = ImageTransformer(config.model).to(config.device)
optimizer = optim.Adam(model.parameters(), lr=1., betas=(0.9, 0.98), eps=1e-9)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: get_lr(step, config))

INFO - <ipython-input-27-fea14b691037> - 2022-06-27 23:27:49,857 - Using device: cuda
INFO - <ipython-input-27-fea14b691037> - 2022-06-27 23:27:49,857 - Using device: cuda
INFO - <ipython-input-27-fea14b691037> - 2022-06-27 23:27:49,857 - Using device: cuda
INFO - <ipython-input-27-fea14b691037> - 2022-06-27 23:27:49,862 - Run name: 0
INFO - <ipython-input-27-fea14b691037> - 2022-06-27 23:27:49,862 - Run name: 0
INFO - <ipython-input-27-fea14b691037> - 2022-06-27 23:27:49,862 - Run name: 0


Files already downloaded and verified


### Training Loop

In [29]:
gain = config.model.initializer_gain

for name, p in model.named_parameters():
    if "layernorm" in name:
        continue
    if p.dim() > 1:
        nn.init.xavier_uniform_(p, gain=np.sqrt(gain)) # Need sqrt for inconsistency between pytorch / TF
    else:
        a =  np.sqrt(3. * gain / p.shape[0])
        nn.init.uniform_(p, -a, a)

def revert_samples(input):
    if config.model.distr == "cat":
        return input
    elif config.model.distr == "dmol":
        return input * 0.5 + 0.5

step = 0
losses_per_dim = torch.zeros(config.model.channels, config.model.image_size, config.model.image_size).to(config.device)

for _ in range(config.train.epochs):
        for _, (imgs, l) in enumerate(loader):
            imgs = imgs.to(config.device)
            model.train()

            scheduler.step()
            optimizer.zero_grad()
            preds = model(imgs)
            loss = model.loss(preds, imgs)
            decay = 0. if step == 0 else 0.99
            if config.model.distr == "dmol":
                losses_per_dim[0,:,:] = losses_per_dim[0,:,:] * decay + (1 - decay) * loss.detach().mean(0) / np.log(2)
            else:
                losses_per_dim = losses_per_dim * decay + (1 - decay) * loss.detach().mean(0) / np.log(2)
            loss = loss.view(loss.shape[0], -1).sum(1)
            loss = loss.mean(0)

            # Show computational graph
            # dot = make_dot(loss, dict(model.named_parameters()))
            # dot.render('test.gv', view=True)

            loss.backward()

            total_norm = 0
            for p in model.parameters():
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
            total_norm = (total_norm ** (1. / 2))

            if config.train.clip_grad_norm > 0.0:
                nn.utils.clip_grad_norm_(model.parameters(), config.train.clip_grad_norm)

            total_norm_post = 0
            for p in model.parameters():
                param_norm = p.grad.data.norm(2)
                total_norm_post += param_norm.item() ** 2
            total_norm_post = (total_norm_post ** (1. / 2))

            optimizer.step()
            bits_per_dim = loss / (np.log(2.) * input_dim)
            acc = model.accuracy(preds, imgs)

            if step % config.train.log_iter == 0:
                logging.info('step: {}; loss: {:.3f}; bits_per_dim: {:.3f}, acc: {:.3f}, grad norm pre: {:.3f}, post: {:.3f}'
                             .format(step, loss.item(), bits_per_dim.item(), acc.item(), total_norm, total_norm_post))
                tb_logger.add_scalar('loss', loss.item(), global_step=step)
                tb_logger.add_scalar('bits_per_dim', bits_per_dim.item(), global_step=step)
                tb_logger.add_scalar('acc', acc.item(), global_step=step)
                tb_logger.add_scalar('grad_norm', total_norm, global_step=step)

            if step % config.train.sample_iter == 0:
                logging.info("Sampling from model: {}".format(args.doc))
                if config.model.distr == "cat":
                    channels = ['r','g','b']
                    color_codes = ['Reds', "Greens", 'Blues']
                    for idx, c in enumerate(channels):
                        ax = sns.heatmap(losses_per_dim[idx,:,:].cpu().numpy(), linewidth=0.5, cmap=color_codes[idx])
                        tb_logger.add_figure("losses_per_dim/{}".format(c), ax.get_figure(), close=True, global_step=step)
                else:
                    ax = sns.heatmap(losses_per_dim[0,:,:].cpu().numpy(), linewidth=0.5, cmap='Blues')
                    tb_logger.add_figure("losses_per_dim", ax.get_figure(), close=True, global_step=step)

                model.eval()
                with torch.no_grad():
                    imgs = revert_samples(imgs)
                    imgs_grid = torchvision.utils.make_grid(imgs[:8, ...], 3)
                    tb_logger.add_image('imgs', imgs_grid, global_step=step)

                    # Evaluate model predictions for the input
                    pred_samples = revert_samples(model.sample_from_preds(preds))
                    pred_samples_grid = torchvision.utils.make_grid(pred_samples[:8, ...], 3)
                    tb_logger.add_image('pred_samples/random', pred_samples_grid, global_step=step)
                    pred_samples = revert_samples(model.sample_from_preds(preds, argmax=True))
                    pred_samples_grid = torchvision.utils.make_grid(pred_samples[:8, ...], 3)
                    tb_logger.add_image('pred_samples/argmax', pred_samples_grid, global_step=step)

                    if args.sample:
                        samples = revert_samples(model.sample(config.train.sample_size, config.device))
                        samples_grid = torchvision.utils.make_grid(samples[:8, ...], 3)
                        tb_logger.add_image('samples', samples_grid, global_step=step)

                    # Argmax samples are not useful for unconditional generation
                    # if config.model.distr == "cat":
                    #     argmax_samples = model.sample(1, config.device, argmax=True)
                    #     samples_grid = torchvision.utils.make_grid(argmax_samples[:8, ...], 3)
                    #     tb_logger.add_image('argmax_samples', samples_grid, global_step=step)
                torch.save(model.state_dict(), os.path.join('transformer_logs', args.doc, "model.pth"))
            step += 1

INFO - <ipython-input-29-44ee206a9f55> - 2022-06-27 23:27:52,070 - step: 0; loss: 1072.975; bits_per_dim: 8.062, acc: 0.004, grad norm pre: 78.396, post: 78.396
INFO - <ipython-input-29-44ee206a9f55> - 2022-06-27 23:27:52,070 - step: 0; loss: 1072.975; bits_per_dim: 8.062, acc: 0.004, grad norm pre: 78.396, post: 78.396
INFO - <ipython-input-29-44ee206a9f55> - 2022-06-27 23:27:52,070 - step: 0; loss: 1072.975; bits_per_dim: 8.062, acc: 0.004, grad norm pre: 78.396, post: 78.396
INFO - <ipython-input-29-44ee206a9f55> - 2022-06-27 23:27:52,074 - Sampling from model: 0
INFO - <ipython-input-29-44ee206a9f55> - 2022-06-27 23:27:52,074 - Sampling from model: 0
INFO - <ipython-input-29-44ee206a9f55> - 2022-06-27 23:27:52,074 - Sampling from model: 0
INFO - <ipython-input-29-44ee206a9f55> - 2022-06-27 23:27:53,016 - step: 10; loss: 1072.973; bits_per_dim: 8.062, acc: 0.005, grad norm pre: 88.140, post: 88.140
INFO - <ipython-input-29-44ee206a9f55> - 2022-06-27 23:27:53,016 - step: 10; loss: 10

KeyboardInterrupt: 