In [107]:

import os
# Ensure matplotlib can write cache files in restricted environments
os.environ.setdefault("MPLCONFIGDIR", "./.matplotlib")
os.makedirs(os.environ["MPLCONFIGDIR"], exist_ok=True)
import matplotlib
matplotlib.use("Agg")

import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.datasets import load_digits
from sklearn import datasets
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F


In [108]:
torch.cuda.is_available()

True

In [109]:

class Digits(Dataset):

    def __init__(self, mode="train", transforms=None):
        digits = load_digits()
        if mode == "train":
            self.data = digits.data[:1000].astype(np.float32)
        elif mode == "val":
            self.data = digits.data[1000:1350].astype(np.float32)
        else:
            self.data = digits.data[1350:].astype(np.float32)

        self.transforms = transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = torch.tensor(self.data[idx], dtype=torch.float32)
        if self.transforms:
            sample = self.transforms(sample)
        return sample


In [110]:
train_data = Digits()
test_data = Digits(mode="val")
len(train_data), len(test_data)

(1000, 350)

In [111]:
train_data[0].shape

torch.Size([64])

In [112]:
import matplotlib.pyplot as plt
import random

img = train_data[random.choice(range(len(train_data)))].reshape((8,8))
plt.imshow(img)
plt.axis("off")

(np.float64(-0.5), np.float64(7.5), np.float64(7.5), np.float64(-0.5))

In [113]:

class MoG(nn.Module):
    def __init__(self, D, K, uniform=False):
        super(MoG, self).__init__()

        # Hyperparameters
        self.uniform = uniform
        self.D = D
        self.K = K

        # Parameters
        self.mu = nn.Parameter(torch.randn(1, self.K, self.D) * 0.25 + 0.5)
        self.log_var = nn.Parameter(-3.0 * torch.ones(1, self.K, self.D))

        if self.uniform:
            self.w = torch.zeros(1, self.K)
            self.w.requires_grad = False
        else:
            self.w = nn.Parameter(torch.zeros(1, self.K))
        self.register_buffer("PI", torch.tensor(np.pi))

    def log_diag_normal(self, x, mu, log_var, reduction="sum", dim=1):
        log_p = (
            -0.5 * torch.log(2.0 * self.PI)
            - 0.5 * log_var
            - 0.5 * torch.exp(-log_var) * (x.unsqueeze(1) - mu) ** 2.0
        )
        return log_p

    def forward(self, x, reduction="mean"):
        log_pi = torch.log(
            F.softmax(self.w, 1)
        ) 
        log_N = torch.sum(self.log_diag_normal(x, self.mu, self.log_var), 2)
        NLL_loss = -torch.logsumexp(log_pi + log_N, 1)  # B

        if reduction == "sum":
            return NLL_loss.sum()
        elif reduction == "mean":
            return NLL_loss.mean()
        else:
            raise ValueError("Either `sum` or `mean`.")

    def sample(self, batch_size=64):
        x_sample = torch.empty(batch_size, self.D)

        pi = F.softmax(
            self.w, 1
        )

        indices = torch.multinomial(pi, batch_size, replacement=True).squeeze()

        for n in range(batch_size):
            indx = indices[n]
            x_sample[n] = self.mu[0, indx] + torch.exp(
                0.5 * self.log_var[0, indx]
            ) * torch.randn(self.D)

        return x_sample

    def log_prob(self, x, reduction='mean'):
        with torch.no_grad():
            log_pi = torch.log(F.softmax(self.w, 1)) 
            log_N = torch.sum(self.log_diag_normal(x, self.mu, self.log_var), 2)  

            log_prob = torch.logsumexp(log_pi + log_N,  1) 

            if reduction == 'sum':
                return log_prob.sum()
            elif reduction == 'mean':
                return log_prob.mean()
            else:
                raise ValueError('Either `sum` or `mean`.')


In [114]:

def save_checkpoint(model, path):
    torch.save(
        {
            'state_dict': model.state_dict(),
            'D': model.D,
            'K': model.K,
            'uniform': model.uniform,
        },
        path,
    )


def load_checkpoint(path):
    ckpt = torch.load(path, map_location='cpu')
    if isinstance(ckpt, dict) and 'state_dict' in ckpt:
        model = MoG(D=ckpt['D'], K=ckpt['K'], uniform=ckpt.get('uniform', False))
        model.load_state_dict(ckpt['state_dict'])
        return model
    if isinstance(ckpt, MoG):
        return ckpt
    raise ValueError('Unexpected checkpoint format')


def evaluation(test_loader, name=None, model_best=None, epoch=None):
    if model_best is None:
        model_best = load_checkpoint(name + ".model")

    model_best.eval()
    loss = 0.0
    N = 0.0
    for indx_batch, test_batch in enumerate(test_loader):
        loss_t = -model_best.log_prob(test_batch, reduction="sum")
        loss = loss + loss_t.item()
        N = N + test_batch.shape[0]
    loss = loss / N

    if epoch is None:
        print(f"FINAL LOSS: nll={loss}")
    else:
        print(f"Epoch: {epoch}, val nll={loss}")

    return loss


In [115]:
def samples_real(name, test_loader):
    # REAL-------
    num_x = 4
    num_y = 4
    x = next(iter(test_loader)).detach().numpy()

    fig, ax = plt.subplots(num_x, num_y)
    for i, ax in enumerate(ax.flatten()):
        plottable_image = np.reshape(x[i], (8, 8))
        ax.imshow(plottable_image, cmap="gray")
        ax.axis("off")

    plt.savefig(name + "_real_images.pdf", bbox_inches="tight")
    plt.close()

In [116]:

def samples_generated(name, data_loader, extra_name=""):
    with torch.no_grad():
        # GENERATIONS-------
        model_best = load_checkpoint(name + ".model")

        num_x = 4
        num_y = 4
        x = model_best.sample(batch_size=num_x * num_y)
        x = x.detach().numpy()

        fig, ax = plt.subplots(num_x, num_y)
        for i, ax in enumerate(ax.flatten()):
            plottable_image = np.reshape(x[i], (8, 8))
            ax.imshow(plottable_image, cmap="gray")
            ax.axis("off")

        plt.savefig(
            name + "_generated_images" + extra_name + ".pdf", bbox_inches="tight"
        )
        plt.close()


In [117]:
def plot_curve(name, nll_val):
    plt.plot(np.arange(len(nll_val)), nll_val, linewidth="3")
    plt.xlabel("epochs")
    plt.ylabel("nll")
    plt.savefig(name + "_nll_val_curve.pdf", bbox_inches="tight")
    plt.close()

In [118]:

def means_save(name, extra_name="", num_x=4, num_y=4):
    with torch.no_grad():
        # GENERATIONS-------
        model_best = load_checkpoint(name + ".model")

        pi = F.softmax(model_best.w, 1).squeeze()

        x = model_best.mu[:, 0 : num_x * num_y]
        N = x.shape[1]
        x = x.squeeze(0).detach().numpy()

        fig, ax = plt.subplots(int(np.sqrt(N)), int(np.sqrt(N)))
        for i, ax in enumerate(ax.flatten()):
            plottable_image = np.reshape(x[i], (8, 8))
            ax.imshow(plottable_image, cmap="gray")
            ax.set_title(f"$\pi$ = {pi[i].item():.5f}")
            ax.axis("off")
        fig.tight_layout()
        plt.savefig(name + "_means_images" + extra_name + ".pdf", bbox_inches="tight")
        plt.close()


  ax.set_title(f"$\pi$ = {pi[i].item():.5f}")


In [119]:

def training(
    name, max_patience, num_epochs, model, optimizer, training_loader, val_loader
):
    nll_val = []
    best_nll = 1000.0
    patience = 0

    # Main loop
    for e in range(num_epochs):
        # TRAINING
        model.train()
        for indx_batch, batch in enumerate(training_loader):
            loss = model.forward(batch)

            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            optimizer.step()

        # Validation
        loss_val = evaluation(val_loader, model_best=model, epoch=e)
        nll_val.append(loss_val)  # save for plotting

        if e == 0:
            print("saved!")
            save_checkpoint(model, name + ".model")
            best_nll = loss_val
        else:
            if loss_val < best_nll:
                print("saved!")
                save_checkpoint(model, name + ".model")
                best_nll = loss_val
                patience = 0
            else:
                patience = patience + 1

        samples_generated(name, val_loader, extra_name="_epoch_" + str(e))

        if patience > max_patience:
            break

    nll_val = np.asarray(nll_val)

    return nll_val


In [120]:

def add_noise(x):
    return x / 17.0 + torch.randn_like(x) / 136.0

transforms = add_noise


In [121]:
train_data = Digits(mode="train", transforms=transforms)
val_data = Digits(mode="val", transforms=transforms)
test_data = Digits(mode="test", transforms=transforms)

training_loader = DataLoader(train_data, batch_size=32, shuffle=True)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

In [127]:
D = 64  # input dimension

K = 25  # the number of neurons in scale (s) and translation (t) nets

lr = 1e-3  # learning rate
num_epochs = 100  # max. number of epochs
max_patience = 20  # an early stopping is used, if training doesn't improve for longer than 20 epochs, it is stopped

In [128]:
name = "mog" + "_" + str(K)
if not (os.path.exists("results/")):
    os.mkdir("results")
result_dir = "results/" + name + "/"
if not (os.path.exists(result_dir)):
    os.mkdir(result_dir)

In [129]:
model = MoG(D=D, K=K, uniform=True)

In [130]:
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad == True], lr=lr
)

In [131]:
nll_val = training(
    name=result_dir + name,
    max_patience=max_patience,
    num_epochs=num_epochs,
    model=model,
    optimizer=optimizer,
    training_loader=training_loader,
    val_loader=val_loader,
)

Epoch: 0, val nll=70.905517578125
saved!
Epoch: 1, val nll=60.159803292410714
saved!
Epoch: 2, val nll=50.68167759486607
saved!
Epoch: 3, val nll=42.62671282087054
saved!
Epoch: 4, val nll=35.953858816964285
saved!
Epoch: 5, val nll=30.3809619140625
saved!
Epoch: 6, val nll=25.73413312639509
saved!
Epoch: 7, val nll=21.77804443359375
saved!
Epoch: 8, val nll=18.405272391183036
saved!
Epoch: 9, val nll=15.488887677873883
saved!
Epoch: 10, val nll=12.9643603515625
saved!
Epoch: 11, val nll=10.753599679129465
saved!
Epoch: 12, val nll=8.890039934430803
saved!
Epoch: 13, val nll=7.281302664620536
saved!
Epoch: 14, val nll=5.892212088448661
saved!
Epoch: 15, val nll=4.653304661342076
saved!
Epoch: 16, val nll=3.5770477730887276
saved!
Epoch: 17, val nll=2.6272486005510602
saved!
Epoch: 18, val nll=1.7713522488730293
saved!
Epoch: 19, val nll=0.9826360511779785
saved!
Epoch: 20, val nll=0.2492590604509626
saved!
Epoch: 21, val nll=-0.40860964366367886
saved!
Epoch: 22, val nll=-1.03739234651

In [132]:
test_loss = evaluation(name=result_dir + name, test_loader=test_loader)
f = open(result_dir + name + "_test_loss.txt", "w")
f.write(str(test_loss))
f.close()

samples_real(result_dir + name, test_loader)
samples_generated(result_dir + name, test_loader, extra_name="FINAL")

means_save(result_dir + name, extra_name="_" + str(K), num_x=5, num_y=5)

plot_curve(result_dir + name, nll_val)

FINAL LOSS: nll=-24.652738771715953
