In [137]:
import os

import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.datasets import load_digits
from sklearn import datasets
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F



In [138]:
class Digits(Dataset):
    """Dataset wrapper around the scikit-learn digits data.

    The rows of ``load_digits().data`` are cut into three fixed, contiguous
    splits and stored as ``float32``:

    * ``mode="train"`` -> rows ``[0, 1000)``
    * ``mode="val"``   -> rows ``[1000, 1350)``
    * any other value  -> rows ``[1350, end)``

    Args:
        mode: Name of the split to expose.
        transforms: Optional callable applied to each sample in
            ``__getitem__``; it is skipped when falsy.
    """

    def __init__(self, mode="train", transforms=None):
        # Pick the row range first, then slice and convert once.
        if mode == "train":
            rows = slice(None, 1000)
        elif mode == "val":
            rows = slice(1000, 1350)
        else:
            rows = slice(1350, None)

        self.data = load_digits().data[rows].astype(np.float32)
        self.transforms = transforms

    def __len__(self):
        """Number of samples in the selected split."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return sample ``idx``, passed through ``transforms`` when one is set."""
        item = self.data[idx]
        return self.transforms(item) if self.transforms else item

In [139]:
# Instantiate the default ("train") split as a quick sanity check.
data = Digits()
# NOTE: this shape is computed but not shown; a notebook cell only displays
# the value of its final expression.
data[0].shape
data

<__main__.Digits at 0x12cb7e9d0>

In [140]:
class RoundStraightThrough(torch.autograd.Function):
    """Element-wise rounding whose backward pass is the identity.

    The forward pass rounds its input; the backward pass hands the incoming
    gradient back untouched, so gradients flow "straight through" the
    rounding step.
    """

    @staticmethod
    def forward(ctx, input):
        """Return ``input`` rounded element-wise."""
        return input.round()

    @staticmethod
    def backward(ctx, grad_output):
        """Return a copy of ``grad_output`` as the gradient w.r.t. ``input``."""
        return torch.clone(grad_output)


In [141]:
def log_min_exp(a, b):
    """Return ``log(|exp(a) - exp(b)|)`` element-wise without leaving log space.

    The larger of the two arguments is factored out, so the result is
    ``max(a, b) + log1p(-exp(min(a, b) - max(a, b)))``. Equal inputs yield
    ``-inf``.
    """
    upper = torch.maximum(a, b)
    # Non-positive gap between the smaller and the larger argument.
    gap = torch.minimum(a, b) - upper
    return upper + torch.log1p(-torch.exp(gap))

In [142]:
def log_integer_probability(x, mean, logscale):
    """Log-probability of integer ``x`` under a discretized logistic.

    Computes, element-wise (with broadcasting),

        log( sigmoid((x + 0.5 - mean) / s) - sigmoid((x - 0.5 - mean) / s) )

    where ``s = exp(clamp(logscale, -5, 5))``.

    Subtracting the two sigmoids (or their logs) directly cancels
    catastrophically far to the right of ``mean``, where both are ~1, and
    produces ``-inf``. With ``u`` and ``l`` the upper and lower arguments,
    this uses the exact factorisation instead:

        sigmoid(u) - sigmoid(l) = sigmoid(u) * sigmoid(-l) * (1 - exp(-(u - l)))

    Here ``u - l = 1 / s`` is strictly positive, so every term stays finite.

    Args:
        x: Tensor of (integer-valued) points to score.
        mean: Location of the logistic; broadcast against ``x``.
        logscale: Log of the logistic scale; clamped to ``[-5, 5]``.

    Returns:
        Tensor of log-probabilities with the broadcast shape of the inputs.
    """
    logscale = torch.clamp(logscale, -5, 5)
    scale = torch.exp(logscale)

    upper = ((x + 0.5) - mean) / scale
    lower = ((x - 0.5) - mean) / scale
    # The bin width in standardized units is exactly 1 / scale.
    log_width_term = torch.log(-torch.expm1(-torch.exp(-logscale)))

    return F.logsigmoid(upper) + F.logsigmoid(-lower) + log_width_term

In [143]:
# --- Model / training configuration ---------------------------------------
D = 64   # input dimension
M = 256  # hidden width of the translation nets

lr = 1e-3          # learning rate
num_epochs = 100   # max. number of epochs
max_patience = 20  # patience budget (in epochs) intended for early stopping


def nett():
    """Build a fresh translation MLP mapping D // 2 -> M -> M -> D // 2."""
    hidden = [nn.Linear(D // 2, M), nn.LeakyReLU(),
              nn.Linear(M, M), nn.LeakyReLU()]
    return nn.Sequential(*hidden, nn.Linear(M, D // 2))


netts = [nett]

In [None]:
class IDF(nn.Module):
    """Integer discrete flow built from additive, rounded coupling layers.

    Each flow step adds the rounded output of a translation network to one
    part of the input (conditioned on the other parts) and then reverses the
    feature order. The latent code is scored with a factorised discretized
    logistic prior whose ``mean`` and ``logscale`` are learnable.

    Two coupling variants are supported, selected by ``self.idf_git``:

    * ``1`` - the input is split in two halves; one net per flow (``self.t``).
    * ``4`` - the input is split in four quarters; four nets per flow
      (``self.t_a`` ... ``self.t_d``), each conditioned on the other three
      quarters.

    Args:
        netts: List of one or four zero-argument callables, each returning a
            new translation network.
        num_flows: Number of coupling + permutation steps.
        D: Data dimensionality; the prior parameters have shape ``(1, D)``.

    Raises:
        ValueError: If ``netts`` does not contain exactly 1 or 4 factories.
    """

    def __init__(self, netts, num_flows, D=2):
        super().__init__()
        if len(netts) == 1:
            self.t = nn.ModuleList([netts[0]() for _ in range(num_flows)])
            self.idf_git = 1

        elif len(netts) == 4:
            self.t_a = nn.ModuleList([netts[0]() for _ in range(num_flows)])
            self.t_b = nn.ModuleList([netts[1]() for _ in range(num_flows)])
            self.t_c = nn.ModuleList([netts[2]() for _ in range(num_flows)])
            self.t_d = nn.ModuleList([netts[3]() for _ in range(num_flows)])
            # Four nets means the four-way coupling. This was previously set
            # to 1, which routed ``coupling`` to the non-existent ``self.t``.
            self.idf_git = 4

        else:
            raise ValueError(f"The transformation net need to be either 1 or 4. The provided net contains {len(netts)} layers.")

        self.num_flows = num_flows
        self.round = RoundStraightThrough.apply

        # Learnable parameters of the factorised prior.
        self.mean = nn.Parameter(torch.zeros(1, D))
        self.logscale = nn.Parameter(torch.ones(1, D))

        self.D = D

    def coupling(self, x, index, forward=True):
        """Apply (or undo) the ``index``-th coupling layer.

        Args:
            x: Tensor of shape ``(batch, features)``; it is chunked along
                dim 1 into 2 or 4 parts depending on ``self.idf_git``.
            index: Which flow step's translation net(s) to use.
            forward: ``True`` adds the rounded translations, ``False``
                subtracts them in reverse order (the exact inverse).

        Raises:
            ValueError: If ``self.idf_git`` is neither 1 nor 4.
        """
        if self.idf_git == 1:
            (xa, xb) = torch.chunk(x, 2, 1)

            shift = self.round(self.t[index](xa))
            yb = xb + shift if forward else xb - shift
            return torch.cat((xa, yb), 1)

        if self.idf_git == 4:
            (xa, xb, xc, xd) = torch.chunk(x, 4, 1)
            if forward:
                # Each quarter is updated using the already-updated quarters
                # before it and the not-yet-updated quarters after it.
                ya = xa + self.round(self.t_a[index](torch.cat((xb, xc, xd), 1)))
                yb = xb + self.round(self.t_b[index](torch.cat((ya, xc, xd), 1)))
                yc = xc + self.round(self.t_c[index](torch.cat((ya, yb, xd), 1)))
                yd = xd + self.round(self.t_d[index](torch.cat((ya, yb, yc), 1)))
            else:
                # Undo the updates in the opposite order (d, c, b, a) so each
                # net sees exactly the conditioning it saw going forward.
                yd = xd - self.round(self.t_d[index](torch.cat((xa, xb, xc), 1)))
                yc = xc - self.round(self.t_c[index](torch.cat((xa, xb, yd), 1)))
                yb = xb - self.round(self.t_b[index](torch.cat((xa, yc, yd), 1)))
                ya = xa - self.round(self.t_a[index](torch.cat((yb, yc, yd), 1)))

            return torch.cat((ya, yb, yc, yd), 1)

        # Previously an unsupported value fell through and returned None.
        raise ValueError(f"Unsupported idf_git value {self.idf_git!r}; expected 1 or 4.")

    def permute(self, x):
        """Reverse the feature order (dim 1); applying it twice is the identity."""
        return x.flip(1)

    def f(self, x):
        """Map data ``x`` to latent ``z``: coupling then permutation, per flow."""
        z = x
        for i in range(self.num_flows):
            z = self.coupling(z, i, forward=True)
            z = self.permute(z)
        return z

    def f_inv(self, z):
        """Map latent ``z`` back to data by undoing ``f`` step by step."""
        x = z
        for i in reversed(range(self.num_flows)):
            x = self.permute(x)
            x = self.coupling(x, i, forward=False)

        return x

    def log_prior(self, x):
        """Per-sample log prior: sum over dim 1 of the discretized-logistic log-prob.

        NOTE(review): the latent is clamped to ``[-20, 20]`` before scoring,
        so every value outside that range is scored as if it sat on the
        boundary. The result is then no longer a normalised log-pmf over the
        integers and can make reported NLLs look better than they are -
        confirm this is intended.
        """
        x = torch.clamp(x, -20, 20)
        log_p = log_integer_probability(x, self.mean, self.logscale)
        return log_p.sum(1)

    def forward(self, x, reduction="avg"):
        """Negative log prior of ``f(x)``.

        Args:
            x: Batch of data, shape ``(batch, D)``.
            reduction: ``'sum'`` sums over the batch; any other value
                averages over it.
        """
        z = self.f(x)
        if reduction == 'sum':
            return -self.log_prior(z).sum()
        else:
            return -self.log_prior(z).mean()

    def sample(self, batch_size, int_max=100):
        """Draw ``batch_size`` samples by inverting the flow on prior samples.

        ``int_max`` is not used; it is kept so existing calls keep working.
        """
        z = self.prior_sample(batch_size=batch_size, D=self.D)
        x = self.f_inv(z)
        return x  # shape: (batch_size, D)

    def prior_sample(self, batch_size, D=2):
        """Sample rounded logistic noise of shape ``(batch_size, self.D)``.

        Uses inverse-CDF sampling: ``scale * logit(u) + mean`` with
        ``u ~ U(0, 1)``, rounded to the nearest integer.

        ``D`` is ignored (``self.D`` is always used, matching the shape of
        ``mean`` / ``logscale``); it is kept for backward compatibility.
        """
        # Same [-5, 5] clamp as the density, so sampler and likelihood agree.
        scale = torch.exp(torch.clamp(self.logscale, -5, 5))
        u = torch.rand(batch_size, self.D, device=self.mean.device, dtype=self.mean.dtype)
        # torch.rand can return exactly 0, which would give logit(u) = -inf.
        eps = torch.finfo(u.dtype).eps
        u = u.clamp(eps, 1.0 - eps)
        x = scale * (torch.log(u) - torch.log1p(-u)) + self.mean
        return torch.round(x)

In [145]:
# Smoke test: build a one-flow IDF from the current `netts` and display its
# module repr as the cell output.
IDF(netts, 1)

IDF(
  (t): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=32, out_features=256, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=256, out_features=256, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=256, out_features=32, bias=True)
    )
  )
)

In [146]:
from torchinfo import summary

num_flows = 8

# Coupling variant: 1 -> two-way split (one translation net per flow),
#                   4 -> four-way split (four translation nets per flow).
idf_git = 4


def _make_translation_net(in_features, out_features):
    """Build a fresh translation MLP: in_features -> M -> M -> out_features."""
    return nn.Sequential(
        nn.Linear(in_features, M),
        nn.LeakyReLU(),
        nn.Linear(M, M),
        nn.LeakyReLU(),
        nn.Linear(M, out_features),
    )


if idf_git == 1:
    # One net per flow: conditions on one half, translates the other half.
    def nett():
        return _make_translation_net(D // 2, D // 2)

    netts = [nett]

elif idf_git == 4:
    # Four nets per flow, all with the same architecture: each conditions on
    # three quarters of the input and translates the remaining quarter.
    def nett_a():
        return _make_translation_net(3 * (D // 4), D // 4)

    nett_b = nett_c = nett_d = nett_a

    netts = [nett_a, nett_b, nett_c, nett_d]

else:
    raise ValueError(f"idf_git must be 1 or 4, got {idf_git!r}")

model = IDF(netts, num_flows, D=D)

# Explicitly select the coupling path that matches `netts`. This follows the
# `idf_git` setting above instead of hard-coding 4, so switching to the
# two-way variant does not route the model to nets it does not have.
model.idf_git = idf_git

summary(
    model=model,
    input_size=(1, D),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
IDF (IDF)                                [1, 64]              --                   128                  True
├─ModuleList (t_a)                       --                   --                   (recursive)          True
│    └─Sequential (0)                    [1, 48]              [1, 16]              --                   True
│    │    └─Linear (0)                   [1, 48]              [1, 256]             12,544               True
│    │    └─LeakyReLU (1)                [1, 256]             [1, 256]             --                   --
│    │    └─Linear (2)                   [1, 256]             [1, 256]             65,792               True
│    │    └─LeakyReLU (3)                [1, 256]             [1, 256]             --                   --
│    │    └─Linear (4)                   [1, 256]             [1, 16]              4,112                True
├─ModuleList (t_b)

In [147]:
def evaluation(test_loader, name=None, model_best=None, epoch=None):
    """Compute the average per-sample loss of a model over a data loader.

    Args:
        test_loader: Iterable of batches; each batch is a tensor whose first
            dimension is the batch size.
        name: Path prefix of a saved model; ``name + ".model"`` is loaded
            when ``model_best`` is not given.
        model_best: Model to evaluate. Calling it with ``reduction="sum"``
            must return the summed loss of the batch.
        epoch: If given, the result is reported as a validation value for
            that epoch; otherwise it is reported as the final loss.

    Returns:
        The summed loss divided by the total number of samples.

    Raises:
        ValueError: If neither ``name`` nor ``model_best`` is provided.
        ZeroDivisionError: If ``test_loader`` yields no samples.
    """
    # EVALUATION
    if model_best is None:
        if name is None:
            raise ValueError("Either `name` or `model_best` must be provided.")
        # load best performing model
        # NOTE(security): torch.load unpickles arbitrary objects - only load
        # model files from a trusted source.
        model_best = torch.load(name + ".model")

    model_best.eval()

    # Evaluate on whatever device the model lives on (if it has parameters).
    first_param = next(iter(model_best.parameters()), None)
    device = first_param.device if first_param is not None else None

    loss = 0.0
    N = 0
    # No gradients are needed here; skipping them avoids building a graph
    # for every evaluation batch.
    with torch.no_grad():
        for test_batch in test_loader:
            if device is not None:
                test_batch = test_batch.to(device)
            loss_t = model_best(test_batch, reduction="sum")
            loss = loss + loss_t.item()
            N = N + test_batch.shape[0]
    loss = loss / N

    if epoch is None:
        print(f"FINAL LOSS: nll={loss}")
    else:
        print(f"Epoch: {epoch}, val nll={loss}")

    return loss

In [148]:
from torch.utils.data import DataLoader

BATCH_SIZE = 32

train_data = Digits(mode="train")
val_data = Digits(mode="val")
test_data = Digits(mode="test")

# Only the training loader is shuffled. Validation and test are iterated in
# a fixed order so repeated evaluations see identical batches.
training_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

In [149]:
from tqdm.auto import tqdm
def training(name, model, device="cpu", lr=0.01):
    """Train ``model`` with Adam and record the validation loss per epoch.

    Reads the notebook-level ``num_epochs``, ``training_loader`` and
    ``val_loader``, and calls ``evaluation`` after every epoch.

    Args:
        name: Run name. Currently unused: no checkpoint is written and no
            early stopping is performed; it is kept for interface
            compatibility.
        model: Module whose call ``model(x)`` returns a scalar loss.
        device: Device the model and every training batch are moved to.
        lr: Adam learning rate. Defaults to 0.01, the value that was
            previously hard-coded here; the notebook-level ``lr`` is only
            used if passed explicitly.

    Returns:
        ``np.ndarray`` with one validation loss per completed epoch.
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    nll_val = []
    for epoch in tqdm(range(num_epochs)):
        # evaluation() switches the model to eval mode, so restore training
        # mode at the start of every epoch.
        model.train()
        for x in tqdm(training_loader, leave=False):
            x = x.to(device)
            loss = model(x)
            optimizer.zero_grad()
            # A fresh graph is built for every batch, so it does not need to
            # be retained after the backward pass.
            loss.backward()
            optimizer.step()

        # evaluation() already prints the per-epoch validation value.
        loss_val = evaluation(val_loader, model_best=model, epoch=epoch)
        nll_val.append(loss_val)  # save for plotting

    return np.asarray(nll_val)

In [150]:
# Train the model built above. The returned array of per-epoch validation
# values is this cell's displayed result.
training(name="IDF", model=model)

100%|██████████| 32/32 [00:01<00:00, 18.23it/s]
  1%|          | 1/100 [00:01<03:04,  1.86s/it]

Epoch: 0, val nll=428.6674776785714
428.6674776785714


100%|██████████| 32/32 [00:01<00:00, 21.64it/s]
  2%|▏         | 2/100 [00:03<02:48,  1.72s/it]

Epoch: 1, val nll=362.94489118303574
362.94489118303574


100%|██████████| 32/32 [00:01<00:00, 18.74it/s]
  3%|▎         | 3/100 [00:05<02:51,  1.77s/it]

Epoch: 2, val nll=327.046875
327.046875


100%|██████████| 32/32 [00:01<00:00, 21.01it/s]
  4%|▍         | 4/100 [00:06<02:44,  1.72s/it]

Epoch: 3, val nll=307.27746651785714
307.27746651785714


100%|██████████| 32/32 [00:01<00:00, 21.68it/s]
  5%|▌         | 5/100 [00:08<02:38,  1.67s/it]

Epoch: 4, val nll=296.2436216517857
296.2436216517857


100%|██████████| 32/32 [00:02<00:00, 12.76it/s]
  6%|▌         | 6/100 [00:11<03:07,  2.00s/it]

Epoch: 5, val nll=289.97143136160713
289.97143136160713


100%|██████████| 32/32 [00:01<00:00, 22.69it/s]
  7%|▋         | 7/100 [00:12<02:50,  1.84s/it]

Epoch: 6, val nll=286.31332310267857
286.31332310267857


100%|██████████| 32/32 [00:01<00:00, 22.15it/s]
  8%|▊         | 8/100 [00:14<02:40,  1.75s/it]

Epoch: 7, val nll=284.10043247767857
284.10043247767857


100%|██████████| 32/32 [00:01<00:00, 23.37it/s]
  9%|▉         | 9/100 [00:15<02:30,  1.66s/it]

Epoch: 8, val nll=282.6889955357143
282.6889955357143


100%|██████████| 32/32 [00:01<00:00, 21.81it/s]
 10%|█         | 10/100 [00:17<02:27,  1.64s/it]

Epoch: 9, val nll=281.71945033482143
281.71945033482143


100%|██████████| 32/32 [00:01<00:00, 19.20it/s]
 11%|█         | 11/100 [00:19<02:30,  1.69s/it]

Epoch: 10, val nll=280.98855747767857
280.98855747767857


100%|██████████| 32/32 [00:01<00:00, 18.99it/s]
 12%|█▏        | 12/100 [00:20<02:31,  1.72s/it]

Epoch: 11, val nll=280.3808286830357
280.3808286830357


100%|██████████| 32/32 [00:01<00:00, 21.17it/s]
 13%|█▎        | 13/100 [00:22<02:27,  1.70s/it]

Epoch: 12, val nll=279.83035435267857
279.83035435267857


100%|██████████| 32/32 [00:01<00:00, 20.09it/s]
 14%|█▍        | 14/100 [00:24<02:26,  1.70s/it]

Epoch: 13, val nll=279.2997739955357
279.2997739955357


100%|██████████| 32/32 [00:01<00:00, 23.72it/s]
 15%|█▌        | 15/100 [00:25<02:20,  1.65s/it]

Epoch: 14, val nll=278.76797154017856
278.76797154017856


100%|██████████| 32/32 [00:01<00:00, 22.54it/s]
 16%|█▌        | 16/100 [00:27<02:15,  1.62s/it]

Epoch: 15, val nll=278.2231752232143
278.2231752232143


100%|██████████| 32/32 [00:01<00:00, 20.75it/s]
 17%|█▋        | 17/100 [00:28<02:15,  1.63s/it]

Epoch: 16, val nll=277.6587220982143
277.6587220982143


100%|██████████| 32/32 [00:01<00:00, 19.55it/s]
 18%|█▊        | 18/100 [00:30<02:16,  1.67s/it]

Epoch: 17, val nll=277.0708314732143
277.0708314732143


100%|██████████| 32/32 [00:01<00:00, 17.14it/s]
 19%|█▉        | 19/100 [00:32<02:24,  1.78s/it]

Epoch: 18, val nll=276.45727957589287
276.45727957589287


100%|██████████| 32/32 [00:02<00:00, 13.59it/s]
 20%|██        | 20/100 [00:35<02:39,  1.99s/it]

Epoch: 19, val nll=275.81640625
275.81640625


100%|██████████| 32/32 [00:01<00:00, 19.36it/s]
 21%|██        | 21/100 [00:37<02:32,  1.93s/it]

Epoch: 20, val nll=275.14706752232144
275.14706752232144


100%|██████████| 32/32 [00:03<00:00, 10.42it/s]
 22%|██▏       | 22/100 [00:40<03:00,  2.32s/it]

Epoch: 21, val nll=274.44811941964286
274.44811941964286


100%|██████████| 32/32 [00:02<00:00, 12.88it/s]
 23%|██▎       | 23/100 [00:42<03:04,  2.40s/it]

Epoch: 22, val nll=273.71853794642857
273.71853794642857


100%|██████████| 32/32 [00:02<00:00, 15.04it/s]
 24%|██▍       | 24/100 [00:45<02:58,  2.35s/it]

Epoch: 23, val nll=272.95727678571427
272.95727678571427


100%|██████████| 32/32 [00:02<00:00, 12.98it/s]
 25%|██▌       | 25/100 [00:47<03:01,  2.42s/it]

Epoch: 24, val nll=272.1630552455357
272.1630552455357


100%|██████████| 32/32 [00:03<00:00, 10.57it/s]
 26%|██▌       | 26/100 [00:50<03:15,  2.65s/it]

Epoch: 25, val nll=271.3345033482143
271.3345033482143


100%|██████████| 32/32 [00:02<00:00, 11.12it/s]
 27%|██▋       | 27/100 [00:53<03:21,  2.75s/it]

Epoch: 26, val nll=270.4702455357143
270.4702455357143


100%|██████████| 32/32 [00:02<00:00, 12.12it/s]
 28%|██▊       | 28/100 [00:56<03:18,  2.76s/it]

Epoch: 27, val nll=269.568818359375
269.568818359375


100%|██████████| 32/32 [00:02<00:00, 12.17it/s]
 29%|██▉       | 29/100 [00:59<03:16,  2.76s/it]

Epoch: 28, val nll=268.628388671875
268.628388671875


100%|██████████| 32/32 [00:02<00:00, 14.59it/s]
 30%|███       | 30/100 [01:01<03:03,  2.63s/it]

Epoch: 29, val nll=267.64718331473216
267.64718331473216


100%|██████████| 32/32 [00:02<00:00, 14.47it/s]
 31%|███       | 31/100 [01:04<02:55,  2.54s/it]

Epoch: 30, val nll=266.623076171875
266.623076171875


100%|██████████| 32/32 [00:03<00:00, 10.03it/s]
 32%|███▏      | 32/100 [01:07<03:09,  2.79s/it]

Epoch: 31, val nll=265.55398856026784
265.55398856026784


100%|██████████| 32/32 [00:03<00:00, 10.22it/s]
 33%|███▎      | 33/100 [01:10<03:15,  2.92s/it]

Epoch: 32, val nll=264.4373758370536
264.4373758370536


100%|██████████| 32/32 [00:02<00:00, 15.11it/s]
 34%|███▍      | 34/100 [01:12<03:00,  2.73s/it]

Epoch: 33, val nll=263.2705970982143
263.2705970982143


100%|██████████| 32/32 [00:02<00:00, 12.82it/s]
 35%|███▌      | 35/100 [01:15<02:55,  2.71s/it]

Epoch: 34, val nll=262.0506570870536
262.0506570870536


100%|██████████| 32/32 [00:02<00:00, 12.47it/s]
 36%|███▌      | 36/100 [01:18<02:52,  2.70s/it]

Epoch: 35, val nll=260.7743833705357
260.7743833705357


100%|██████████| 32/32 [00:02<00:00, 14.37it/s]
 37%|███▋      | 37/100 [01:20<02:43,  2.60s/it]

Epoch: 36, val nll=259.438017578125
259.438017578125


100%|██████████| 32/32 [00:02<00:00, 14.96it/s]
 38%|███▊      | 38/100 [01:22<02:34,  2.50s/it]

Epoch: 37, val nll=258.037626953125
258.037626953125


100%|██████████| 32/32 [00:02<00:00, 15.12it/s]
 39%|███▉      | 39/100 [01:25<02:27,  2.41s/it]

Epoch: 38, val nll=256.56875279017856
256.56875279017856


100%|██████████| 32/32 [00:02<00:00, 15.45it/s]
 40%|████      | 40/100 [01:27<02:20,  2.35s/it]

Epoch: 39, val nll=255.0263671875
255.0263671875


100%|██████████| 32/32 [00:02<00:00, 15.36it/s]
 41%|████      | 41/100 [01:29<02:15,  2.30s/it]

Epoch: 40, val nll=253.404814453125
253.404814453125


100%|██████████| 32/32 [00:02<00:00, 14.55it/s]
 42%|████▏     | 42/100 [01:31<02:13,  2.31s/it]

Epoch: 41, val nll=251.69776925223215
251.69776925223215


100%|██████████| 32/32 [00:02<00:00, 14.32it/s]
 43%|████▎     | 43/100 [01:34<02:12,  2.32s/it]

Epoch: 42, val nll=249.89807338169643
249.89807338169643


100%|██████████| 32/32 [00:02<00:00, 15.11it/s]
 44%|████▍     | 44/100 [01:36<02:08,  2.30s/it]

Epoch: 43, val nll=247.99742047991072
247.99742047991072


100%|██████████| 32/32 [00:02<00:00, 15.16it/s]
 45%|████▌     | 45/100 [01:38<02:05,  2.28s/it]

Epoch: 44, val nll=245.98653878348213
245.98653878348213


100%|██████████| 32/32 [00:04<00:00,  7.79it/s]
 46%|████▌     | 46/100 [01:43<02:37,  2.91s/it]

Epoch: 45, val nll=243.8544921875
243.8544921875


100%|██████████| 32/32 [00:02<00:00, 13.29it/s]
 47%|████▋     | 47/100 [01:45<02:28,  2.80s/it]

Epoch: 46, val nll=241.58872767857142
241.58872767857142


100%|██████████| 32/32 [00:02<00:00, 13.54it/s]
 48%|████▊     | 48/100 [01:48<02:20,  2.71s/it]

Epoch: 47, val nll=239.17452706473213
239.17452706473213


100%|██████████| 32/32 [00:02<00:00, 11.31it/s]
 49%|████▉     | 49/100 [01:51<02:21,  2.78s/it]

Epoch: 48, val nll=236.594462890625
236.594462890625


100%|██████████| 32/32 [00:02<00:00, 13.75it/s]
 50%|█████     | 50/100 [01:53<02:13,  2.68s/it]

Epoch: 49, val nll=233.82784737723213
233.82784737723213


100%|██████████| 32/32 [00:02<00:00, 15.66it/s]
 51%|█████     | 51/100 [01:55<02:03,  2.52s/it]

Epoch: 50, val nll=230.84982003348213
230.84982003348213


100%|██████████| 32/32 [00:01<00:00, 16.34it/s]
 52%|█████▏    | 52/100 [01:57<01:54,  2.39s/it]

Epoch: 51, val nll=227.62999162946429
227.62999162946429


100%|██████████| 32/32 [00:02<00:00, 14.06it/s]
 53%|█████▎    | 53/100 [02:00<01:52,  2.39s/it]

Epoch: 52, val nll=224.13096261160715
224.13096261160715


100%|██████████| 32/32 [00:02<00:00, 14.60it/s]
 54%|█████▍    | 54/100 [02:02<01:48,  2.37s/it]

Epoch: 53, val nll=220.3059095982143
220.3059095982143


100%|██████████| 32/32 [00:02<00:00, 14.21it/s]
 55%|█████▌    | 55/100 [02:04<01:46,  2.37s/it]

Epoch: 54, val nll=216.09510881696428
216.09510881696428


100%|██████████| 32/32 [00:02<00:00, 12.60it/s]
 56%|█████▌    | 56/100 [02:07<01:48,  2.46s/it]

Epoch: 55, val nll=211.42110909598213
211.42110909598213


100%|██████████| 32/32 [00:02<00:00, 13.75it/s]
 57%|█████▋    | 57/100 [02:09<01:45,  2.46s/it]

Epoch: 56, val nll=206.18087472098213
206.18087472098213


100%|██████████| 32/32 [00:02<00:00, 11.90it/s]
 58%|█████▊    | 58/100 [02:12<01:47,  2.56s/it]

Epoch: 57, val nll=200.23370396205357
200.23370396205357


100%|██████████| 32/32 [00:03<00:00,  8.20it/s]
 59%|█████▉    | 59/100 [02:16<02:03,  3.01s/it]

Epoch: 58, val nll=193.38148716517858
193.38148716517858


100%|██████████| 32/32 [00:02<00:00, 11.48it/s]
 60%|██████    | 60/100 [02:19<01:59,  2.98s/it]

Epoch: 59, val nll=185.334248046875
185.334248046875


100%|██████████| 32/32 [00:02<00:00, 15.07it/s]
 61%|██████    | 61/100 [02:21<01:47,  2.76s/it]

Epoch: 60, val nll=175.6484974888393
175.6484974888393


100%|██████████| 32/32 [00:02<00:00, 13.57it/s]
 62%|██████▏   | 62/100 [02:24<01:42,  2.70s/it]

Epoch: 61, val nll=163.61057896205358
163.61057896205358


100%|██████████| 32/32 [00:02<00:00, 13.47it/s]
 63%|██████▎   | 63/100 [02:26<01:37,  2.63s/it]

Epoch: 62, val nll=148.01606863839285
148.01606863839285


100%|██████████| 32/32 [00:01<00:00, 16.39it/s]
 64%|██████▍   | 64/100 [02:28<01:28,  2.46s/it]

Epoch: 63, val nll=126.83507393973214
126.83507393973214


100%|██████████| 32/32 [00:02<00:00, 15.02it/s]
 65%|██████▌   | 65/100 [02:31<01:23,  2.40s/it]

Epoch: 64, val nll=97.70989188058036
97.70989188058036


100%|██████████| 32/32 [00:02<00:00, 14.01it/s]
 66%|██████▌   | 66/100 [02:33<01:21,  2.40s/it]

Epoch: 65, val nll=64.3813166155134
64.3813166155134


100%|██████████| 32/32 [00:02<00:00, 14.80it/s]
 67%|██████▋   | 67/100 [02:35<01:18,  2.37s/it]

Epoch: 66, val nll=37.41371547154018
37.41371547154018


100%|██████████| 32/32 [00:02<00:00, 13.88it/s]
 68%|██████▊   | 68/100 [02:38<01:16,  2.39s/it]

Epoch: 67, val nll=19.577398158482143
19.577398158482143


100%|██████████| 32/32 [00:02<00:00, 13.29it/s]
 69%|██████▉   | 69/100 [02:40<01:15,  2.44s/it]

Epoch: 68, val nll=10.012950875418527
10.012950875418527


100%|██████████| 32/32 [00:02<00:00, 14.72it/s]
 70%|███████   | 70/100 [02:43<01:12,  2.40s/it]

Epoch: 69, val nll=5.463633989606585
5.463633989606585


100%|██████████| 32/32 [00:02<00:00, 14.96it/s]
 71%|███████   | 71/100 [02:45<01:08,  2.36s/it]

Epoch: 70, val nll=3.295044446672712
3.295044446672712


100%|██████████| 32/32 [00:02<00:00, 14.60it/s]
 72%|███████▏  | 72/100 [02:47<01:05,  2.35s/it]

Epoch: 71, val nll=2.1809759085518974
2.1809759085518974


100%|██████████| 32/32 [00:02<00:00, 14.54it/s]
 73%|███████▎  | 73/100 [02:50<01:03,  2.34s/it]

Epoch: 72, val nll=1.5512095860072546
1.5512095860072546


100%|██████████| 32/32 [00:02<00:00, 14.83it/s]
 74%|███████▍  | 74/100 [02:52<01:00,  2.32s/it]

Epoch: 73, val nll=1.1633645302908762
1.1633645302908762


100%|██████████| 32/32 [00:02<00:00, 13.68it/s]
 75%|███████▌  | 75/100 [02:54<00:59,  2.36s/it]

Epoch: 74, val nll=0.9076896885463169
0.9076896885463169


100%|██████████| 32/32 [00:02<00:00, 14.95it/s]
 76%|███████▌  | 76/100 [02:57<00:56,  2.34s/it]

Epoch: 75, val nll=0.7299786758422852
0.7299786758422852


100%|██████████| 32/32 [00:02<00:00, 11.86it/s]
 77%|███████▋  | 77/100 [02:59<00:57,  2.48s/it]

Epoch: 76, val nll=0.6011958040509905
0.6011958040509905


100%|██████████| 32/32 [00:02<00:00, 13.46it/s]
 78%|███████▊  | 78/100 [03:02<00:54,  2.50s/it]

Epoch: 77, val nll=0.5046908814566476
0.5046908814566476


100%|██████████| 32/32 [00:02<00:00, 13.78it/s]
 79%|███████▉  | 79/100 [03:04<00:52,  2.48s/it]

Epoch: 78, val nll=0.43036600657871793
0.43036600657871793


100%|██████████| 32/32 [00:02<00:00, 14.34it/s]
 80%|████████  | 80/100 [03:07<00:48,  2.44s/it]

Epoch: 79, val nll=0.3718054635184152
0.3718054635184152


100%|██████████| 32/32 [00:02<00:00, 13.42it/s]
 81%|████████  | 81/100 [03:09<00:46,  2.46s/it]

Epoch: 80, val nll=0.3247720881870815
0.3247720881870815


100%|██████████| 32/32 [00:01<00:00, 16.19it/s]
 82%|████████▏ | 82/100 [03:11<00:42,  2.35s/it]

Epoch: 81, val nll=0.286373051234654
0.286373051234654


100%|██████████| 32/32 [00:02<00:00, 15.84it/s]
 83%|████████▎ | 83/100 [03:14<00:38,  2.29s/it]

Epoch: 82, val nll=0.25457730974469867
0.25457730974469867


100%|██████████| 32/32 [00:02<00:00, 14.94it/s]
 84%|████████▍ | 84/100 [03:16<00:36,  2.28s/it]

Epoch: 83, val nll=0.2279236125946045
0.2279236125946045


100%|██████████| 32/32 [00:02<00:00, 14.08it/s]
 85%|████████▌ | 85/100 [03:18<00:34,  2.31s/it]

Epoch: 84, val nll=0.20533762523106167
0.20533762523106167


100%|██████████| 32/32 [00:02<00:00, 14.55it/s]
 86%|████████▌ | 86/100 [03:21<00:32,  2.32s/it]

Epoch: 85, val nll=0.18601478712899344
0.18601478712899344


100%|██████████| 32/32 [00:02<00:00, 14.13it/s]
 87%|████████▋ | 87/100 [03:23<00:30,  2.33s/it]

Epoch: 86, val nll=0.169342280796596
0.169342280796596


100%|██████████| 32/32 [00:02<00:00, 14.97it/s]
 88%|████████▊ | 88/100 [03:25<00:27,  2.31s/it]

Epoch: 87, val nll=0.1548464148385184
0.1548464148385184


100%|██████████| 32/32 [00:02<00:00, 14.40it/s]
 89%|████████▉ | 89/100 [03:27<00:25,  2.32s/it]

Epoch: 88, val nll=0.14215619904654367
0.14215619904654367


100%|██████████| 32/32 [00:02<00:00, 13.73it/s]
 90%|█████████ | 90/100 [03:30<00:23,  2.36s/it]

Epoch: 89, val nll=0.13097739219665527
0.13097739219665527


100%|██████████| 32/32 [00:02<00:00, 14.49it/s]
 91%|█████████ | 91/100 [03:32<00:21,  2.36s/it]

Epoch: 90, val nll=0.12107413905007498
0.12107413905007498


100%|██████████| 32/32 [00:02<00:00, 14.48it/s]
 92%|█████████▏| 92/100 [03:35<00:18,  2.37s/it]

Epoch: 91, val nll=0.11225550515311104
0.11225550515311104


100%|██████████| 32/32 [00:02<00:00, 14.95it/s]
 93%|█████████▎| 93/100 [03:37<00:16,  2.34s/it]

Epoch: 92, val nll=0.10436553478240967
0.10436553478240967


100%|██████████| 32/32 [00:02<00:00, 14.27it/s]
 94%|█████████▍| 94/100 [03:39<00:14,  2.34s/it]

Epoch: 93, val nll=0.0972756052017212
0.0972756052017212


100%|██████████| 32/32 [00:02<00:00, 14.58it/s]
 95%|█████████▌| 95/100 [03:42<00:11,  2.33s/it]

Epoch: 94, val nll=0.09087897096361433
0.09087897096361433


100%|██████████| 32/32 [00:01<00:00, 16.12it/s]
 96%|█████████▌| 96/100 [03:44<00:09,  2.27s/it]

Epoch: 95, val nll=0.08508648123059954
0.08508648123059954


100%|██████████| 32/32 [00:02<00:00, 14.71it/s]
 97%|█████████▋| 97/100 [03:46<00:06,  2.29s/it]

Epoch: 96, val nll=0.07982284614018031
0.07982284614018031


100%|██████████| 32/32 [00:02<00:00, 14.01it/s]
 98%|█████████▊| 98/100 [03:49<00:04,  2.33s/it]

Epoch: 97, val nll=0.07502449376242501
0.07502449376242501


100%|██████████| 32/32 [00:02<00:00, 13.82it/s]
 99%|█████████▉| 99/100 [03:51<00:02,  2.36s/it]

Epoch: 98, val nll=0.07063724790300642
0.07063724790300642


100%|██████████| 32/32 [00:02<00:00, 14.75it/s]
100%|██████████| 100/100 [03:53<00:00,  2.34s/it]

Epoch: 99, val nll=0.06661463499069215
0.06661463499069215





array([4.28667478e+02, 3.62944891e+02, 3.27046875e+02, 3.07277467e+02,
       2.96243622e+02, 2.89971431e+02, 2.86313323e+02, 2.84100432e+02,
       2.82688996e+02, 2.81719450e+02, 2.80988557e+02, 2.80380829e+02,
       2.79830354e+02, 2.79299774e+02, 2.78767972e+02, 2.78223175e+02,
       2.77658722e+02, 2.77070831e+02, 2.76457280e+02, 2.75816406e+02,
       2.75147068e+02, 2.74448119e+02, 2.73718538e+02, 2.72957277e+02,
       2.72163055e+02, 2.71334503e+02, 2.70470246e+02, 2.69568818e+02,
       2.68628389e+02, 2.67647183e+02, 2.66623076e+02, 2.65553989e+02,
       2.64437376e+02, 2.63270597e+02, 2.62050657e+02, 2.60774383e+02,
       2.59438018e+02, 2.58037627e+02, 2.56568753e+02, 2.55026367e+02,
       2.53404814e+02, 2.51697769e+02, 2.49898073e+02, 2.47997420e+02,
       2.45986539e+02, 2.43854492e+02, 2.41588728e+02, 2.39174527e+02,
       2.36594463e+02, 2.33827847e+02, 2.30849820e+02, 2.27629992e+02,
       2.24130963e+02, 2.20305910e+02, 2.16095109e+02, 2.11421109e+02,
      

In [152]:

# Draw 16 samples from the model and show them as 8x8 images on a 4x4 grid.
model.eval()
with torch.no_grad():
    raw_samples = model.sample(batch_size=16).detach().cpu()

# If the sampler returned a 3-D tensor, keep only index 0 along dim 1.
samples = raw_samples[:, 0, :] if raw_samples.dim() == 3 else raw_samples
samples_np = samples.numpy()

fig, axes = plt.subplots(4, 4, figsize=(6, 6))
for ax, img in zip(axes.ravel(), samples_np):
    ax.imshow(img.reshape(8, 8), cmap="gray")
    ax.set_axis_off()
plt.tight_layout()
plt.show()


RuntimeError: shape '[16, 2, 64]' is invalid for input of size 1024