In [37]:
import os
import sys
import gc

import time
import json
import random
import math
import numpy as np

import torch
from torch.optim.adamw import AdamW
from torch.nn.utils import clip_grad_norm_
import torch.distributed as dist
from torch.utils.data.dataloader import DataLoader
from torchvision.utils import save_image


from wolf.data import load_datasets, get_batch, preprocess, postprocess
from wolf import WolfModel
from wolf.utils import total_grad_norm
from wolf.optim import ExponentialScheduler

from experiments.options import parse_args

import autoreload

%autoreload 2

In [23]:
def is_master(rank):
    """Return True when this process should do logging and checkpointing.

    Both rank -1 (a single, non-distributed process) and rank 0 (the first
    distributed worker) count as the master.
    """
    return 0 >= rank

In [24]:
def is_distributed(rank):
    """Return True when running as one worker of a distributed job.

    A negative rank (the notebook uses -1) marks a single-process run.
    """
    return 0 <= rank

In [25]:
def logging(info, logfile=None):
    """Print ``info`` to stdout and, if ``logfile`` is given, also write it there.

    The log file is flushed after each message so its contents stay current
    even if the run is interrupted.
    """
    print(info)
    if logfile is None:
        return
    print(info, file=logfile)
    logfile.flush()

In [26]:
def get_optimizer(learning_rate, parameters, betas, eps, amsgrad, step_decay, weight_decay, warmup_steps, init_lr):
    """Build an AdamW optimizer and its learning-rate scheduler.

    Args:
        learning_rate: base learning rate passed to AdamW as ``lr``.
        parameters: iterable of parameters to optimize.
        betas, eps, amsgrad, weight_decay: forwarded unchanged to AdamW.
        step_decay, warmup_steps, init_lr: forwarded positionally to
            ``ExponentialScheduler`` (wolf.optim). Presumably these are the
            per-step decay factor, the warm-up length in steps and the
            learning rate at the start of warm-up (TODO confirm against
            wolf.optim).

    Returns:
        ``(optimizer, scheduler)``.
    """
    optimizer = AdamW(parameters, lr=learning_rate, betas=betas, eps=eps,
                      amsgrad=amsgrad, weight_decay=weight_decay)
    scheduler = ExponentialScheduler(optimizer, step_decay, warmup_steps, init_lr)
    return optimizer, scheduler

In [27]:
def setup(args):
    """Prepare a training run: output directories, log file, seeds, data and model.

    Mutates ``args`` in place by adding the derived fields ``nx``, ``n_bins``,
    ``test_k``, ``checkpoint_name``, ``result_path``, ``log``, ``cuda``,
    ``world_size`` and ``device``.

    Returns:
        ``(args, (train_data, val_data), (train_index, val_index), wolf)``.
        ``train_index`` is a shuffled ``np.arange`` over the training set.
        ``val_index`` is an unshuffled ``np.arange`` over the validation set.

    Raises:
        AssertionError: if ``args.image_size`` is not supported for a known dataset.
    """
    def check_dataset():
        # Reject image sizes that are not supported for known datasets; other
        # dataset names are passed through unchecked.
        if dataset == 'cifar10':
            assert image_size == 32, 'CIFAR-10 expected image size 32 but got {}'.format(image_size)
        elif dataset.startswith('lsun'):
            assert image_size in [128, 256]
        elif dataset == 'celeba':
            assert image_size in [256, 512]
        elif dataset == 'imagenet':
            assert image_size in [64, 128, 256]

    dataset = args.dataset
    if args.category is not None:
        dataset = dataset + '_' + args.category
    image_size = args.image_size
    check_dataset()

    nc = 3  # colour channels
    args.nx = image_size ** 2 * nc  # dimensions per image
    n_bits = args.n_bits
    args.n_bins = 2. ** n_bits
    args.test_k = 5  # nsamples passed to wolf.loss during eval()

    model_path = args.model_path
    args.checkpoint_name = os.path.join(model_path, 'checkpoint')

    result_path = os.path.join(model_path, 'images')
    args.result_path = result_path
    data_path = args.data_path
    master = is_master(args.rank)

    if master:
        os.makedirs(model_path, exist_ok=True)
        os.makedirs(result_path, exist_ok=True)
        # Fresh runs truncate the log; resumed runs append to it.
        log_mode = 'w' if args.recover < 0 else 'a'
        args.log = open(os.path.join(model_path, 'log.txt'), log_mode)
    else:
        args.log = None

    args.cuda = torch.cuda.is_available()
    # Offset the seed by rank so each distributed worker gets its own stream.
    random_seed = args.seed + args.rank if args.rank >= 0 else args.seed
    if args.recover >= 0:
        # Perturb the seed on resume. The offset is drawn before `random` is
        # reseeded below, presumably so a resumed run does not replay the
        # original run's random stream.
        random_seed += random.randint(0, 1024)
    logging("Rank {}: random seed={}".format(args.rank, random_seed), logfile=args.log)
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    device = torch.device('cuda', args.local_rank) if args.cuda else torch.device('cpu')
    if args.cuda:
        torch.cuda.set_device(device)
        torch.cuda.manual_seed(random_seed)

    torch.backends.cudnn.benchmark = True

    args.world_size = int(os.environ["WORLD_SIZE"]) if is_distributed(args.rank) else 1
    logging("Rank {}: ".format(args.rank) + str(args), args.log)

    train_data, val_data = load_datasets(dataset, image_size, data_path=data_path)
    train_index = np.arange(len(train_data))
    np.random.shuffle(train_index)
    val_index = np.arange(len(val_data))

    if master:
        logging('Data size: training: {}, val: {}'.format(len(train_index), len(val_index)), args.log)

    config_file = os.path.join(model_path, 'config.json')
    if args.recover >= 0:
        # Resume with the configuration saved by the original run.
        with open(config_file, 'r') as f:
            params = json.load(f)
    else:
        with open(args.config, 'r') as f:
            params = json.load(f)
        if master:
            # Only the master saves the copy. No other rank creates model_path,
            # and concurrent writers could clobber each other's output.
            with open(config_file, 'w') as f:
                json.dump(params, f, indent=2)

    wolf = WolfModel.from_params(params)
    wolf.to_device(device)
    args.device = device

    return args, (train_data, val_data), (train_index, val_index), wolf

In [28]:
def init_dataloader(args, train_data, val_data):
    """Create the training and validation data loaders.

    Returns:
        ``(train_loader, train_sampler, val_loader)``.
        ``train_sampler`` is a ``DistributedSampler`` in distributed runs and
        ``None`` otherwise. ``val_loader`` is ``None`` on non-master ranks.
    """
    train_sampler = None
    if is_distributed(args.rank):
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data, rank=args.rank, num_replicas=args.world_size, shuffle=True)

    # The loader shuffles by itself only when no sampler is doing it.
    train_loader = DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True,
    )

    val_loader = None
    if is_master(args.rank):
        val_loader = DataLoader(
            val_data,
            batch_size=args.eval_batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
        )

    return train_loader, train_sampler, val_loader

In [29]:
def init_model(args, train_data, train_index, wolf):
    """Initialise ``wolf`` from a random batch of training images via ``wolf.init``.

    Draws ``args.init_batch_size`` distinct indices from ``train_index``. The
    chosen images are preprocessed to ``args.n_bits`` before being passed in.
    """
    wolf.eval()
    n_init = args.init_batch_size
    logging('Rank {}, init model: {} instances'.format(args.rank, n_init), args.log)
    chosen = np.random.choice(train_index, n_init, replace=False)
    images, labels = get_batch(train_data, chosen)
    images = preprocess(images.to(args.device), args.n_bits)
    wolf.init(images, y=labels.to(args.device), init_scale=1.0)

In [30]:
def reconstruct(args, epoch, val_data, val_index, wolf):
    """Save a grid of validation images next to their reconstructions.

    Encodes up to 16 randomly chosen validation images with ``random=False``,
    decodes them back, and logs the max and mean absolute reconstruction error.
    Writes ``reconstruct{epoch}.png`` to ``args.result_path``, with each
    original followed by its reconstruction.

    ``val_index`` is not modified.
    """
    logging('reconstruct', args.log)
    wolf.eval()
    n = 16
    # permutation() shuffles a copy, so the caller's index array is left intact.
    chosen = np.random.permutation(val_index)[:n]
    img, y = get_batch(val_data, chosen)
    img = img.to(args.device)
    y = y.to(args.device)

    z, epsilon = wolf.encode(img, y=y, n_bits=args.n_bits, random=False)
    epsilon = epsilon.squeeze(1)
    z = z.squeeze(1) if z is not None else z
    img_recon = wolf.decode(epsilon, z=z, n_bits=args.n_bits)

    # Round-trip the originals through preprocess/postprocess so they are in
    # the same representation as the decoder output before comparing.
    img = postprocess(preprocess(img, args.n_bits), args.n_bits)
    abs_err = (img_recon - img).abs()
    logging('Err: {:.4f}, {:.4f}'.format(abs_err.max().item(), abs_err.mean().item()), args.log)

    # Interleave originals and reconstructions as [img_0, recon_0, img_1, recon_1, ...].
    # This works for any number of images, including fewer than n.
    comparison = torch.stack([img, img_recon], dim=1).flatten(0, 1).cpu()
    image_file = 'reconstruct{}.png'.format(epoch)
    save_image(comparison, os.path.join(args.result_path, image_file), nrow=16)

In [31]:
def sample(args, epoch, wolf):
    """Save square grids of model samples at several temperatures.

    For each tau in 0.7, 0.8, 0.9 and 1.0, draws 64 images (image sizes above
    128) or 256 images (otherwise). Each grid is saved as
    ``sample{epoch}.t{tau}.png`` under ``args.result_path``, and the total
    time taken is logged.
    """
    logging('sampling', args.log)
    wolf.eval()
    if args.image_size > 128:
        n = 64
    else:
        n = 256
    grid_width = int(math.sqrt(n))
    tic = time.time()
    shape = (3, args.image_size, args.image_size)
    for tau in (0.7, 0.8, 0.9, 1.0):
        imgs = wolf.synthesize(n, shape, tau=tau, n_bits=args.n_bits, device=args.device)
        out_path = os.path.join(args.result_path, 'sample{}.t{:.1f}.png'.format(epoch, tau))
        save_image(imgs, out_path, nrow=grid_width)
    logging('time: {:.1f}s'.format(time.time() - tic), args.log)

In [32]:
def eval(args, val_loader, wolf):  # NOTE: shadows the builtin eval(); name kept for callers
    """Evaluate ``wolf`` on the validation loader and log per-instance averages.

    Each batch is scored with ``wolf.loss(..., nsamples=args.test_k)`` under
    ``torch.no_grad()``.

    Returns:
        ``(nll, kl, nent, bpd, nepd)``:
        * ``nll`` is the summed generative, KL and dequantization terms plus
          the constant ``nx * log(n_bins / 2)``.
        * ``bpd`` and ``nepd`` are ``nll`` and ``nent`` divided by
          ``nx * ln 2``, i.e. expressed in bits per dimension.

    Raises:
        ValueError: if ``val_loader`` yields no instances.
    """
    wolf.eval()
    wolf.sync()
    gnll = 0.
    nent = 0.
    kl = 0.
    num_insts = 0
    device = args.device
    n_bits = args.n_bits
    n_bins = args.n_bins
    nx = args.nx
    test_k = args.test_k
    # Callers already wrap eval() in torch.no_grad(). Repeating it here means a
    # direct call cannot accidentally build autograd graphs over the whole set.
    with torch.no_grad():
        for data, y in val_loader:
            batch_size = len(data)
            data = data.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            loss_gen, loss_kl, loss_dequant = wolf.loss(data, y=y, n_bits=n_bits, nsamples=test_k)
            gnll += loss_gen.sum().item()
            kl += loss_kl.sum().item()
            nent += loss_dequant.sum().item()
            num_insts += batch_size

    if num_insts == 0:
        raise ValueError('eval: the validation loader yielded no instances')

    gnll = gnll / num_insts
    nent = nent / num_insts
    kl = kl / num_insts
    nll = gnll + kl + nent + np.log(n_bins / 2.) * nx
    bpd = nll / (nx * np.log(2.0))
    nepd = nent / (nx * np.log(2.0))
    logging('Avg  NLL: {:.2f}, KL: {:.2f}, NENT: {:.2f}, BPD: {:.4f}, NEPD: {:.4f}'.format(
        nll, kl, nent, bpd, nepd), args.log)
    return nll, kl, nent, bpd, nepd

In [33]:
# Keys of the best-validation-result bookkeeping stored in every checkpoint.
_BEST_KEYS = ('best_epoch', 'best_nll', 'best_bpd', 'best_kl', 'best_nent', 'best_nepd')


def _nll_stats(gnll_sum, kl_sum, nent_sum, num_insts, n_bins, nx):
    """Convert summed losses into per-instance ``(nll, bpd, kl, nent, nepd)``."""
    nums = max(num_insts, 1)  # avoid dividing by zero when every step was skipped
    gnll = gnll_sum / nums
    kl = kl_sum / nums
    nent = nent_sum / nums
    nll = gnll + kl + nent + np.log(n_bins / 2.) * nx
    bpd = nll / (nx * np.log(2.0))
    nepd = nent / (nx * np.log(2.0))
    return nll, bpd, kl, nent, nepd


def _checkpoint(epoch, step, wolf, optimizer, scheduler, best):
    """Bundle model/optimizer/scheduler state and best-so-far metrics for torch.save."""
    state = {'epoch': epoch,
             'step': step,
             'model': wolf.state_dict(),
             'optimizer': optimizer.state_dict(),
             'scheduler': scheduler.state_dict()}
    state.update(best)
    return state


def _backward_chunk(wolf, data, y, n_bits, nsamples, batch_size):
    """Backpropagate one micro-batch's summed loss, scaled by the full batch size.

    Returns the summed generative, KL and dequantization losses as Python floats.
    """
    loss_gen, loss_kl, loss_dequant = wolf.loss(data, y=y, n_bits=n_bits, nsamples=nsamples)
    loss_gen = loss_gen.sum()
    loss_kl = loss_kl.sum()
    loss_dequant = loss_dequant.sum()
    # Dividing by the full batch size, not the chunk size, makes the accumulated
    # gradient equal to the gradient of the mean loss over the whole batch.
    loss = (loss_gen + loss_kl + loss_dequant) / batch_size
    loss.backward()
    return loss_gen.item(), loss_kl.item(), loss_dequant.item()


def train(args, train_loader, train_index, train_sampler, val_loader, val_data, val_index, wolf):
    """Train ``wolf`` for ``args.epochs`` epochs, validating and checkpointing along the way.

    Each step splits the batch into ``args.batch_steps`` micro-batches. Their
    gradients are accumulated before a single optimizer/scheduler step. Steps
    whose gradient norm is NaN or infinite are skipped and counted.

    The master rank writes:
      * ``checkpoint{step}.tar`` every 1000 steps within an epoch;
      * ``checkpoint0.tar`` (plus ``wolf.save``) whenever validation NLL improves;
      * ``checkpoint1.tar`` at the end of every epoch.

    With ``args.recover >= 0``, training resumes from ``checkpoint{args.recover}.tar``.
    """
    epochs = args.epochs
    train_k = args.train_k
    n_bits = args.n_bits
    n_bins = args.n_bins
    nx = args.nx
    grad_clip = args.grad_clip
    batch_steps = args.batch_steps
    distributed = is_distributed(args.rank)
    master = is_master(args.rank)

    steps_per_checkpoint = 1000  # mid-epoch checkpoint frequency, in steps

    device = args.device
    log = args.log

    lr_warmups = args.warmup_steps
    init_lr = 1e-7
    betas = (args.beta1, args.beta2)
    eps = args.eps
    amsgrad = args.amsgrad
    lr_decay = args.lr_decay
    weight_decay = args.weight_decay

    optimizer, scheduler = get_optimizer(args.lr, wolf.parameters(), betas, eps,
                                         amsgrad=amsgrad, step_decay=lr_decay,
                                         weight_decay=weight_decay,
                                         warmup_steps=lr_warmups, init_lr=init_lr)
    if args.recover >= 0:
        checkpoint_name = args.checkpoint_name + '{}.tar'.format(args.recover)
        print(f"Rank = {args.rank}, loading from checkpoint {checkpoint_name}")

        checkpoint = torch.load(checkpoint_name, map_location=args.device)
        start_epoch = checkpoint['epoch']
        resume_step = checkpoint['step']
        wolf.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        best = {key: checkpoint[key] for key in _BEST_KEYS}
        del checkpoint
        if master:
            with torch.no_grad():
                logging('Evaluating after resuming model...', log)
                eval(args, val_loader, wolf)
    else:
        start_epoch = 1
        resume_step = -1
        best = dict.fromkeys(_BEST_KEYS, 1e12)
        best['best_epoch'] = 0

    for epoch in range(start_epoch, epochs + 1):
        wolf.train()
        if distributed:
            train_sampler.set_epoch(epoch)

        lr = scheduler.get_lr()[0]
        start_time = time.time()
        if master:
            logging('Epoch: %d (lr=%.6f, betas=(%.1f, %.3f), eps=%.1e, amsgrad=%s, lr decay=%.6f, clip=%.1f, l2=%.1e, train_k=%d)' % (
                epoch, lr, betas[0], betas[1], eps, amsgrad, lr_decay, grad_clip, weight_decay, train_k), log)

        gnll = torch.zeros(1, device=device)
        kl = torch.zeros(1, device=device)
        nent = torch.zeros(1, device=device)
        num_insts = torch.zeros(1, device=device)
        num_back = 0
        num_nans = 0
        if args.cuda:
            torch.cuda.empty_cache()
        gc.collect()

        # Steps up to and including `skip_through` were already trained before a
        # mid-epoch checkpoint. Skip them only in the first resumed epoch.
        skip_through = resume_step if epoch == start_epoch else -1

        # data: [batch_size, n_channel, H, W]
        # labels: [batch_size]
        for step, (data, y) in enumerate(train_loader):
            if step <= skip_through:
                continue
            optimizer.zero_grad()
            batch_size = len(data)
            data = data.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            data_list = [data] if batch_steps == 1 else data.chunk(batch_steps, dim=0)
            y_list = [y] if batch_steps == 1 else y.chunk(batch_steps, dim=0)

            gnll_batch = 0.
            kl_batch = 0.
            nent_batch = 0.
            # Disable allreduce while accumulating gradients of all but the last chunk.
            if distributed:
                wolf.disable_allreduce()
            for data_chunk, y_chunk in zip(data_list[:-1], y_list[:-1]):
                chunk_gnll, chunk_kl, chunk_nent = _backward_chunk(
                    wolf, data_chunk, y_chunk, n_bits, train_k, batch_size)
                gnll_batch += chunk_gnll
                kl_batch += chunk_kl
                nent_batch += chunk_nent
            # Enable allreduce for the last chunk's backward pass.
            if distributed:
                wolf.enable_allreduce()
            chunk_gnll, chunk_kl, chunk_nent = _backward_chunk(
                wolf, data_list[-1], y_list[-1], n_bits, train_k, batch_size)
            gnll_batch += chunk_gnll
            kl_batch += chunk_kl
            nent_batch += chunk_nent

            if grad_clip > 0:
                grad_norm = clip_grad_norm_(wolf.parameters(), grad_clip)
            else:
                grad_norm = total_grad_norm(wolf.parameters())

            # Skip the update when the gradient norm is NaN or infinite; such a
            # step would push non-finite values into the optimizer and parameters.
            if math.isfinite(grad_norm):
                optimizer.step()
                scheduler.step()
                num_insts += batch_size
                gnll += gnll_batch
                kl += kl_batch
                nent += nent_batch
            else:
                num_nans += 1

            if args.cuda and step % 10 == 0:
                torch.cuda.empty_cache()

            if step % args.log_interval == 0 and master:
                # Erase the previous in-place progress line before writing the new one.
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                train_nll, bits_per_pixel, train_kl, train_nent, nent_per_pixel = _nll_stats(
                    gnll.item(), kl.item(), nent.item(), num_insts.item(), n_bins, nx)
                curr_lr = scheduler.get_lr()[0]
                processed = step * batch_size * args.world_size
                log_info = '[{}/{} ({:.0f}%) lr={:.6f}, {}] NLL: {:.2f}, BPD: {:.4f}, KL: {:.2f}, NENT: {:.2f}, NEPD: {:.4f}'.format(
                    processed, len(train_index),
                    100. * processed / len(train_index), curr_lr, num_nans,
                    train_nll, bits_per_pixel, train_kl, train_nent, nent_per_pixel)

                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

            if step > 0 and step % steps_per_checkpoint == 0 and master:
                # Mid-epoch checkpoint, resumable at this step.
                torch.save(_checkpoint(epoch, step, wolf, optimizer, scheduler, best),
                           args.checkpoint_name + '{}.tar'.format(step))

        if distributed:
            # All ranks must issue the reductions in the same order.
            for stat in (gnll, kl, nent, num_insts):
                dist.reduce(stat, dst=0, op=dist.ReduceOp.SUM)

        if master:
            sys.stdout.write("\b" * num_back)
            sys.stdout.write(" " * num_back)
            sys.stdout.write("\b" * num_back)
            train_nll, bits_per_pixel, train_kl, train_nent, nent_per_pixel = _nll_stats(
                gnll.item(), kl.item(), nent.item(), num_insts.item(), n_bins, nx)
            logging('Average NLL: {:.2f}, BPD: {:.4f}, KL: {:.2f}, NENT: {:.2f}, NEPD: {:.4f}, time: {:.1f}s'.format(
                    train_nll, bits_per_pixel, train_kl, train_nent, nent_per_pixel, time.time() - start_time), log)
            logging('-' * 125, log)

            if epoch < args.valid_epochs or epoch % args.valid_epochs == 0:
                with torch.no_grad():
                    val_nll, val_kl, val_nent, val_bpd, val_nepd = eval(args, val_loader, wolf)
                    if val_nll < best['best_nll']:
                        best.update(best_epoch=epoch, best_nll=val_nll, best_bpd=val_bpd,
                                    best_kl=val_kl, best_nent=val_nent, best_nepd=val_nepd)
                        wolf.save(args.model_path)
                        torch.save(_checkpoint(epoch + 1, -1, wolf, optimizer, scheduler, best),
                                   args.checkpoint_name + '{}.tar'.format(0))
                    try:
                        reconstruct(args, epoch, val_data, val_index, wolf)
                    except RuntimeError:
                        logging('Reconstruction failed.', log)
                    try:
                        sample(args, epoch, wolf)
                    except RuntimeError:
                        logging('Sampling failed', log)
            logging('Best NLL: {:.2f}, KL: {:.2f}, NENT: {:.2f}, BPD: {:.4f}, NEPD: {:.4f}, epoch: {}'.format(
                best['best_nll'], best['best_kl'], best['best_nent'], best['best_bpd'],
                best['best_nepd'], best['best_epoch']), log)
            logging('=' * 125, log)
            # End-of-epoch checkpoint, resumable at the start of the next epoch.
            torch.save(_checkpoint(epoch + 1, -1, wolf, optimizer, scheduler, best),
                       args.checkpoint_name + '{}.tar'.format(1))

In [34]:
def main(args):
    """Entry point: set up the run, initialise the model on fresh runs, then train."""
    args, (train_data, val_data), (train_index, val_index), wolf = setup(args)

    if is_master(args.rank):
        n_params = sum(p.numel() for p in wolf.parameters())
        logging('# of Parameters: %d' % n_params, args.log)
        # Fresh runs initialise from training data; resumed runs load weights in train().
        if args.recover < 0:
            init_model(args, train_data, train_index, wolf)
            wolf.sync()

    if is_distributed(args.rank):
        wolf.init_distributed(args.rank, args.local_rank)

    loaders = init_dataloader(args, train_data, val_data)
    train_loader, train_sampler, val_loader = loaders
    train(args, train_loader, train_index, train_sampler, val_loader, val_data, val_index, wolf)

In [35]:
from argparse import Namespace

# Run configuration for a single-process CIFAR-10 Glow model. Presumably this
# stands in for experiments.options.parse_args() inside the notebook.
args_dict = dict(
    rank=-1,            # -1 => not distributed (see is_distributed)
    local_rank=0,
    config='experiments/configs/cifar10/glow/glow-cat-uni.json',
    batch_size=256,
    eval_batch_size=1000,
    batch_steps=2,      # micro-batches per optimizer step
    init_batch_size=1024,
    epochs=100,
    valid_epochs=10,
    seed=65537,
    train_k=1,
    log_interval=10,
    lr=0.001,
    warmup_steps=50,
    lr_decay=0.999997,
    beta1=0.9,
    beta2=0.999,
    eps=1e-08,
    weight_decay=1e-06,
    amsgrad=False,
    grad_clip=0.0,      # 0 => no clipping
    dataset='cifar10',
    category=None,
    image_size=32,
    workers=4,
    n_bits=8,
    model_path='experiments/save_model',
    data_path='experiments/cifar_data',
    recover=-1,         # -1 => fresh run; otherwise checkpoint number to resume
)

args = Namespace(**args_dict)

In [36]:
# This notebook runs a single, non-distributed process: rank -1, local rank 0.
assert args.rank == -1 and args.local_rank == 0, \
    'single process expects rank -1 and local rank 0, but got rank ({}) and local rank ({})'.format(args.rank, args.local_rank)
main(args)

Rank -1: random seed=65537
Rank -1: Namespace(amsgrad=False, batch_size=256, batch_steps=2, beta1=0.9, beta2=0.999, category=None, checkpoint_name='experiments/save_model/checkpoint', config='experiments/configs/cifar10/glow/glow-cat-uni.json', cuda=True, data_path='experiments/cifar_data', dataset='cifar10', epochs=100, eps=1e-08, eval_batch_size=1000, grad_clip=0.0, image_size=32, init_batch_size=1024, local_rank=0, log=<_io.TextIOWrapper name='experiments/save_model/log.txt' mode='w' encoding='UTF-8'>, log_interval=10, lr=0.001, lr_decay=0.999997, model_path='experiments/save_model', n_bins=256.0, n_bits=8, nx=3072, rank=-1, recover=-1, result_path='experiments/save_model/images', seed=65537, test_k=5, train_k=1, valid_epochs=10, warmup_steps=50, weight_decay=1e-06, workers=4, world_size=1)
Files already downloaded and verified
Data size: training: 50000, val: 10000
# of Parameters: 50494081
Rank -1, init model: 1024 instances
+++
tensor([[[[0.4431, 0.5529, 0.5490,  ..., 0.4039, 0.5

Epoch: 1 (lr=0.000000, betas=(0.9, 0.999), eps=1.0e-08, amsgrad=False, lr decay=0.999997, clip=0.0, l2=1.0e-06, train_k=1)
+++
tensor([[[[0.9255, 0.9020, 0.8863,  ..., 0.4078, 0.4235, 0.5804],
          [0.8863, 0.9176, 0.8824,  ..., 0.3843, 0.3686, 0.3922],
          [0.5529, 0.6157, 0.6000,  ..., 0.3725, 0.3725, 0.3765],
          ...,
          [0.3882, 0.3216, 0.3176,  ..., 0.2549, 0.2157, 0.2039],
          [0.3843, 0.3373, 0.3020,  ..., 0.3765, 0.3569, 0.3137],
          [0.3686, 0.3255, 0.2980,  ..., 0.3098, 0.2784, 0.2588]],

         [[0.9373, 0.9176, 0.9059,  ..., 0.5569, 0.5765, 0.6863],
          [0.9098, 0.9255, 0.8980,  ..., 0.5490, 0.5490, 0.5529],
          [0.6510, 0.6902, 0.6784,  ..., 0.5373, 0.5412, 0.5451],
          ...,
          [0.3961, 0.3529, 0.3529,  ..., 0.2980, 0.2627, 0.2510],
          [0.4039, 0.3725, 0.3333,  ..., 0.4078, 0.3961, 0.3569],
          [0.3843, 0.3490, 0.3216,  ..., 0.3490, 0.3216, 0.3059]],

         [[0.9843, 0.9647, 0.9529,  ..., 0.7686

+++
tensor([[[[0.8863, 0.8510, 0.8941,  ..., 0.9686, 0.9569, 0.9451],
          [0.8863, 0.8745, 0.9059,  ..., 0.9529, 0.9529, 0.9843],
          [0.9020, 0.8980, 0.9373,  ..., 0.9137, 0.9176, 0.9373],
          ...,
          [0.8941, 0.8824, 0.8980,  ..., 0.9020, 0.9255, 0.9373],
          [0.8667, 0.8627, 0.8549,  ..., 0.9529, 0.9490, 0.9608],
          [0.9020, 0.8941, 0.8627,  ..., 0.9647, 0.9725, 0.9765]],

         [[0.8471, 0.8275, 0.8706,  ..., 0.9451, 0.9333, 0.9255],
          [0.8549, 0.8431, 0.8745,  ..., 0.9333, 0.9294, 0.9608],
          [0.8667, 0.8549, 0.8941,  ..., 0.8824, 0.8863, 0.9059],
          ...,
          [0.8627, 0.8471, 0.8627,  ..., 0.8784, 0.9020, 0.9137],
          [0.8353, 0.8314, 0.8235,  ..., 0.9176, 0.9137, 0.9255],
          [0.8706, 0.8627, 0.8314,  ..., 0.9333, 0.9412, 0.9412]],

         [[0.7804, 0.7412, 0.7843,  ..., 0.8706, 0.8588, 0.8510],
          [0.7843, 0.7569, 0.7922,  ..., 0.8471, 0.8549, 0.8863],
          [0.7961, 0.7765, 0.8196,  ..

[0/50000 (0%) lr=0.000020, 0] NLL: 12858.75, BPD: 6.0388, KL: 2.30, NENT: 0.00, NEPD: 0.0000+++
tensor([[[[0.4784, 0.3725, 0.5020,  ..., 0.2588, 0.2549, 0.3569],
          [0.8431, 0.6824, 0.5922,  ..., 0.2706, 0.2588, 0.3529],
          [0.4392, 0.3255, 0.5059,  ..., 0.2000, 0.1922, 0.2627],
          ...,
          [0.2863, 0.2431, 0.2275,  ..., 0.4275, 0.4353, 0.4588],
          [0.3725, 0.2471, 0.2275,  ..., 0.4510, 0.4353, 0.4549],
          [0.4118, 0.2431, 0.2353,  ..., 0.4588, 0.4431, 0.4667]],

         [[0.4157, 0.3059, 0.4314,  ..., 0.1804, 0.1647, 0.2078],
          [0.7922, 0.6157, 0.5216,  ..., 0.1922, 0.1804, 0.2196],
          [0.3804, 0.2588, 0.4353,  ..., 0.1373, 0.1294, 0.1412],
          ...,
          [0.1569, 0.1216, 0.1216,  ..., 0.3725, 0.3922, 0.4000],
          [0.2431, 0.1294, 0.1255,  ..., 0.3961, 0.3922, 0.3961],
          [0.2706, 0.1255, 0.1333,  ..., 0.4039, 0.3961, 0.4118]],

         [[0.4196, 0.2941, 0.3961,  ..., 0.2784, 0.2353, 0.2510],
          [0

+++
tensor([[[[0.7412, 0.9451, 0.9922,  ..., 0.6235, 0.9843, 0.9961],
          [0.7373, 0.9451, 0.9922,  ..., 0.6235, 0.9843, 0.9961],
          [0.7412, 0.9333, 0.9843,  ..., 0.6118, 0.9725, 0.9922],
          ...,
          [0.8784, 0.9686, 0.9765,  ..., 0.0667, 0.0745, 0.2471],
          [0.8784, 0.9451, 0.9490,  ..., 0.0667, 0.0588, 0.2784],
          [0.8824, 0.9529, 0.9647,  ..., 0.0392, 0.0431, 0.2745]],

         [[0.8039, 0.9608, 0.9922,  ..., 0.6275, 0.9804, 0.9961],
          [0.8078, 0.9647, 0.9922,  ..., 0.6275, 0.9804, 0.9961],
          [0.7961, 0.9608, 0.9961,  ..., 0.6235, 0.9725, 0.9922],
          ...,
          [0.9059, 0.9882, 0.9922,  ..., 0.0784, 0.1098, 0.3098],
          [0.9216, 0.9922, 0.9882,  ..., 0.0824, 0.0980, 0.3333],
          [0.9176, 0.9843, 0.9882,  ..., 0.0549, 0.0627, 0.3176]],

         [[0.8431, 0.9647, 0.9882,  ..., 0.6471, 0.9843, 0.9961],
          [0.8510, 0.9686, 0.9922,  ..., 0.6588, 0.9882, 0.9961],
          [0.8353, 0.9647, 0.9922,  ..

+++
tensor([[[[0.4627, 0.4392, 0.4196,  ..., 0.6510, 0.3529, 0.2980],
          [0.3843, 0.4353, 0.4275,  ..., 0.6902, 0.3529, 0.3333],
          [0.3529, 0.3765, 0.3647,  ..., 0.5294, 0.4314, 0.4353],
          ...,
          [0.5451, 0.5451, 0.5608,  ..., 0.5725, 0.5725, 0.5843],
          [0.5451, 0.5451, 0.5647,  ..., 0.5294, 0.5216, 0.5176],
          [0.5490, 0.5490, 0.5137,  ..., 0.5294, 0.5333, 0.5333]],

         [[0.5647, 0.5412, 0.5216,  ..., 0.6784, 0.4157, 0.3882],
          [0.4667, 0.5373, 0.5294,  ..., 0.7137, 0.3961, 0.4039],
          [0.4275, 0.4824, 0.4667,  ..., 0.5647, 0.4824, 0.5137],
          ...,
          [0.6275, 0.6157, 0.6118,  ..., 0.6235, 0.6353, 0.6353],
          [0.6314, 0.6353, 0.6353,  ..., 0.6157, 0.6157, 0.6039],
          [0.6314, 0.6431, 0.5961,  ..., 0.6431, 0.6549, 0.6588]],

         [[0.3255, 0.2588, 0.2314,  ..., 0.6863, 0.3843, 0.2980],
          [0.2745, 0.2627, 0.2471,  ..., 0.7216, 0.4000, 0.3451],
          [0.2627, 0.2157, 0.1922,  ..

+++
tensor([[[[0.8706, 0.8627, 0.8824,  ..., 0.7765, 0.7608, 0.7490],
          [0.8980, 0.8941, 0.9098,  ..., 0.8000, 0.7882, 0.7725],
          [0.7569, 0.7490, 0.7608,  ..., 0.6784, 0.6706, 0.6588],
          ...,
          [0.8000, 0.7961, 0.7961,  ..., 0.7804, 0.7804, 0.7804],
          [0.7843, 0.7843, 0.7922,  ..., 0.7725, 0.7765, 0.7765],
          [0.7843, 0.7804, 0.7882,  ..., 0.7882, 0.7922, 0.7961]],

         [[0.9569, 0.9686, 0.9686,  ..., 0.8706, 0.8745, 0.8706],
          [0.9765, 0.9843, 0.9804,  ..., 0.8824, 0.8902, 0.8863],
          [0.8196, 0.8275, 0.8235,  ..., 0.7608, 0.7647, 0.7608],
          ...,
          [0.9098, 0.9059, 0.9059,  ..., 0.8863, 0.8784, 0.8784],
          [0.8941, 0.8941, 0.9020,  ..., 0.8824, 0.8784, 0.8784],
          [0.8902, 0.8902, 0.8980,  ..., 0.9020, 0.8980, 0.8980]],

         [[0.9765, 0.9725, 0.9686,  ..., 0.9020, 0.9059, 0.9059],
          [1.0000, 0.9922, 0.9843,  ..., 0.9137, 0.9176, 0.9137],
          [0.8471, 0.8392, 0.8275,  ..

+++
tensor([[[[0.9647, 0.9725, 0.9804,  ..., 0.9490, 0.9490, 0.9608],
          [0.9451, 0.9451, 0.9647,  ..., 0.9451, 0.9412, 0.9451],
          [0.9294, 0.9294, 0.9373,  ..., 0.9451, 0.9451, 0.9529],
          ...,
          [0.3176, 0.3176, 0.3137,  ..., 0.2706, 0.2706, 0.2784],
          [0.3176, 0.3176, 0.3176,  ..., 0.3098, 0.3137, 0.3255],
          [0.3176, 0.3176, 0.3098,  ..., 0.3176, 0.3216, 0.3333]],

         [[0.9725, 0.9725, 0.9843,  ..., 0.9412, 0.9412, 0.9569],
          [0.9569, 0.9529, 0.9686,  ..., 0.9373, 0.9333, 0.9373],
          [0.9412, 0.9490, 0.9569,  ..., 0.9373, 0.9373, 0.9451],
          ...,
          [0.3137, 0.3137, 0.3098,  ..., 0.2353, 0.2353, 0.2471],
          [0.3176, 0.3176, 0.3137,  ..., 0.2745, 0.2824, 0.2902],
          [0.3176, 0.3137, 0.3059,  ..., 0.2824, 0.2863, 0.2980]],

         [[0.9961, 1.0000, 1.0000,  ..., 0.9882, 0.9843, 0.9922],
          [0.9882, 0.9843, 0.9961,  ..., 0.9882, 0.9804, 0.9804],
          [0.9843, 0.9765, 0.9804,  ..

+++
tensor([[[[0.4275, 0.4235, 0.4078,  ..., 0.4627, 0.4588, 0.4627],
          [0.4000, 0.3843, 0.3647,  ..., 0.4510, 0.4471, 0.4431],
          [0.4039, 0.3843, 0.3843,  ..., 0.4431, 0.4353, 0.4314],
          ...,
          [0.1176, 0.1176, 0.1216,  ..., 0.1725, 0.1765, 0.1843],
          [0.1255, 0.1255, 0.1255,  ..., 0.1922, 0.1922, 0.1961],
          [0.1373, 0.1373, 0.1333,  ..., 0.2078, 0.2118, 0.2157]],

         [[0.4039, 0.4000, 0.3843,  ..., 0.4235, 0.4196, 0.4235],
          [0.3686, 0.3529, 0.3412,  ..., 0.4157, 0.4118, 0.4078],
          [0.3451, 0.3333, 0.3451,  ..., 0.4078, 0.4039, 0.4000],
          ...,
          [0.0941, 0.0902, 0.0824,  ..., 0.1608, 0.1647, 0.1804],
          [0.1020, 0.0980, 0.0863,  ..., 0.1804, 0.1843, 0.1922],
          [0.1098, 0.1059, 0.0941,  ..., 0.1961, 0.2039, 0.2118]],

         [[0.4431, 0.4431, 0.4392,  ..., 0.4588, 0.4510, 0.4549],
          [0.4078, 0.3961, 0.3922,  ..., 0.4431, 0.4392, 0.4353],
          [0.3765, 0.3647, 0.3725,  ..

+++
tensor([[[[0.7294, 0.6627, 0.7294,  ..., 0.3176, 0.3373, 0.2588],
          [0.5647, 0.5137, 0.5647,  ..., 0.2627, 0.2471, 0.2431],
          [0.6196, 0.6275, 0.6196,  ..., 0.3725, 0.3020, 0.2588],
          ...,
          [0.4353, 0.5059, 0.4353,  ..., 0.5412, 0.5608, 0.5804],
          [0.4078, 0.4588, 0.4078,  ..., 0.4510, 0.4784, 0.4980],
          [0.5647, 0.5255, 0.5647,  ..., 0.4549, 0.4235, 0.4471]],

         [[0.8118, 0.7490, 0.8118,  ..., 0.2863, 0.3608, 0.3020],
          [0.6392, 0.5882, 0.6392,  ..., 0.2588, 0.2667, 0.2745],
          [0.6863, 0.6941, 0.6863,  ..., 0.4392, 0.3490, 0.2863],
          ...,
          [0.5451, 0.6078, 0.5451,  ..., 0.6275, 0.6431, 0.6588],
          [0.5020, 0.5529, 0.5020,  ..., 0.5529, 0.5725, 0.5882],
          [0.6392, 0.6000, 0.6392,  ..., 0.5569, 0.5216, 0.5412]],

         [[0.7922, 0.7373, 0.7922,  ..., 0.2902, 0.3804, 0.3255],
          [0.6275, 0.5843, 0.6275,  ..., 0.2745, 0.2863, 0.2902],
          [0.6627, 0.6706, 0.6627,  ..

+++
tensor([[[[0.1098, 0.1137, 0.1176,  ..., 0.1608, 0.2549, 0.3020],
          [0.1176, 0.1176, 0.1216,  ..., 0.0980, 0.1451, 0.2353],
          [0.1451, 0.1373, 0.1569,  ..., 0.1176, 0.1333, 0.1647],
          ...,
          [0.1333, 0.1412, 0.1529,  ..., 0.1451, 0.1529, 0.1412],
          [0.1608, 0.1647, 0.1608,  ..., 0.1373, 0.1451, 0.1529],
          [0.2000, 0.1686, 0.1490,  ..., 0.1412, 0.1333, 0.1333]],

         [[0.1059, 0.1098, 0.1098,  ..., 0.1686, 0.2392, 0.2627],
          [0.1059, 0.1098, 0.1176,  ..., 0.1255, 0.1412, 0.2000],
          [0.1216, 0.1176, 0.1412,  ..., 0.1255, 0.1137, 0.1255],
          ...,
          [0.1137, 0.1020, 0.1098,  ..., 0.1176, 0.1255, 0.1137],
          [0.1098, 0.1098, 0.1255,  ..., 0.1176, 0.1255, 0.1333],
          [0.1294, 0.1137, 0.1255,  ..., 0.1333, 0.1255, 0.1255]],

         [[0.1098, 0.1137, 0.1216,  ..., 0.1451, 0.2039, 0.2078],
          [0.1098, 0.1176, 0.1294,  ..., 0.1137, 0.1098, 0.1412],
          [0.1294, 0.1255, 0.1412,  ..

+++
tensor([[[[0.0039, 0.0000, 0.0549,  ..., 0.9647, 0.9608, 0.9686],
          [0.0078, 0.0000, 0.0549,  ..., 0.9765, 0.9804, 0.9804],
          [0.0039, 0.0000, 0.0588,  ..., 0.9804, 0.9804, 0.9804],
          ...,
          [0.3922, 0.5294, 0.6118,  ..., 0.0392, 0.0471, 0.0588],
          [0.4745, 0.4000, 0.6235,  ..., 0.0314, 0.0157, 0.0078],
          [0.4980, 0.3961, 0.6431,  ..., 0.0118, 0.0078, 0.0118]],

         [[0.0000, 0.0000, 0.0863,  ..., 0.9922, 0.9922, 0.9961],
          [0.0039, 0.0000, 0.0824,  ..., 0.9961, 0.9961, 0.9961],
          [0.0000, 0.0000, 0.0902,  ..., 0.9922, 0.9961, 0.9961],
          ...,
          [0.4118, 0.5725, 0.6431,  ..., 0.1176, 0.1294, 0.1373],
          [0.4000, 0.4314, 0.6706,  ..., 0.1020, 0.0745, 0.0588],
          [0.3725, 0.3725, 0.6706,  ..., 0.0196, 0.0118, 0.0118]],

         [[0.0000, 0.0000, 0.0549,  ..., 0.9922, 0.9961, 0.9922],
          [0.0078, 0.0039, 0.0627,  ..., 0.9882, 0.9922, 0.9922],
          [0.0039, 0.0000, 0.0745,  ..

+++
tensor([[[[0.8941, 0.9098, 0.9176,  ..., 0.8549, 0.8314, 0.8549],
          [0.8980, 0.9098, 0.9098,  ..., 0.8588, 0.8314, 0.8588],
          [0.8980, 0.9098, 0.9098,  ..., 0.8588, 0.7059, 0.8588],
          ...,
          [0.2706, 0.2627, 0.2588,  ..., 0.5843, 0.5490, 0.5843],
          [0.2627, 0.2706, 0.2667,  ..., 0.5765, 0.5686, 0.5765],
          [0.2941, 0.2627, 0.2353,  ..., 0.5765, 0.5765, 0.5765]],

         [[0.9294, 0.9333, 0.9412,  ..., 0.8745, 0.8549, 0.8745],
          [0.9216, 0.9333, 0.9373,  ..., 0.8706, 0.8431, 0.8706],
          [0.9216, 0.9333, 0.9373,  ..., 0.8667, 0.7176, 0.8667],
          ...,
          [0.2706, 0.2902, 0.2824,  ..., 0.5569, 0.5137, 0.5569],
          [0.2941, 0.3216, 0.3059,  ..., 0.5490, 0.5373, 0.5490],
          [0.3647, 0.3137, 0.2667,  ..., 0.5529, 0.5490, 0.5529]],

         [[0.9647, 0.9647, 0.9569,  ..., 0.9412, 0.9176, 0.9412],
          [0.9686, 0.9686, 0.9647,  ..., 0.9373, 0.9020, 0.9373],
          [0.9686, 0.9725, 0.9647,  ..

+++
tensor([[[[0.3725, 0.3725, 0.3804,  ..., 0.4000, 0.3882, 0.4235],
          [0.3490, 0.3569, 0.3608,  ..., 0.3765, 0.3647, 0.3608],
          [0.3725, 0.3725, 0.3804,  ..., 0.4000, 0.3882, 0.4235],
          ...,
          [0.5725, 0.5412, 0.4941,  ..., 0.4588, 0.4706, 0.4706],
          [0.3765, 0.3412, 0.3333,  ..., 0.4314, 0.4510, 0.4588],
          [0.3137, 0.3176, 0.3216,  ..., 0.4157, 0.4275, 0.4275]],

         [[0.3569, 0.3569, 0.3686,  ..., 0.4000, 0.3843, 0.4196],
          [0.3412, 0.3451, 0.3490,  ..., 0.3725, 0.3608, 0.3569],
          [0.3569, 0.3569, 0.3686,  ..., 0.4000, 0.3843, 0.4196],
          ...,
          [0.5451, 0.5020, 0.4588,  ..., 0.4275, 0.4392, 0.4392],
          [0.3333, 0.2980, 0.2941,  ..., 0.3961, 0.4157, 0.4235],
          [0.2706, 0.2784, 0.2824,  ..., 0.3804, 0.3922, 0.3922]],

         [[0.3882, 0.3882, 0.3922,  ..., 0.4471, 0.4314, 0.4588],
          [0.3922, 0.3961, 0.3961,  ..., 0.4314, 0.4196, 0.4118],
          [0.3882, 0.3882, 0.3922,  ..

KeyboardInterrupt: 