In [1]:
import matplotlib.pyplot as plt
import torch
from models_dif import SoftmaxWeight, LocationScaleFlow, DIFDensityEstimator

In [2]:
###MNIST###

import torchvision.datasets as datasets
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)
images = mnist_trainset.data.flatten(start_dim=1).float()
temp = (images + torch.rand_like(images))/256

def pre_process(x, lbda):
    return torch.logit(lbda*torch.ones_like(x) + x*(1-2*lbda))

def inverse_pre_process(x, lbda):
    return torch.sigmoid((x- lbda*torch.ones_like(x))/(1-2*lbda))

lbda = 1e-6
target_samples = pre_process(temp, lbda)
p = target_samples.shape[-1]

In [10]:
import torch
from torch import nn

class LocationScaleFlow(nn.Module):
    def __init__(self, K, p):
        super().__init__()
        self.K = K
        self.p = p

        self.m = nn.Parameter(torch.randn(self.K, self.p))
        self.log_s = nn.Parameter(torch.zeros(self.K, self.p))

    def backward(self, z):
        desired_size = list(z.shape)
        desired_size.insert(-1, self.K)
        Z = z.unsqueeze(-2).expand(desired_size)
        return Z * torch.exp(self.log_s).expand_as(Z) + self.m.expand_as(Z)

    def forward(self, x):
        desired_size = list(x.shape)
        desired_size.insert(-1, self.K)
        X = x.unsqueeze(-2).expand(desired_size)
        return (X-self.m.expand_as(X))/torch.exp(self.log_s).expand_as(X)

    def log_det_J(self,x):
        return -self.log_s.sum(-1)

class FullRankLocationScaleFlow(nn.Module):
    def __init__(self, K, p):
        super().__init__()
        self.K = K
        self.p = p

        self.m = nn.Parameter(torch.randn(self.K, self.p))
        self.chol = torch.eye(self.p).unsqueeze(0).repeat(self.K, 1,1)

    def forward(self, z):
        desired_size_Z_M = list(z.shape)
        desired_size_Z_M.insert(-1, self.K)
        desired_size_S = list(z.shape)
        desired_size_S.insert(-1, self.K)
        desired_size_S.insert(-1, self.p)
        return ((self.chol.expand(desired_size_S))@(z.unsqueeze(-2).expand(desired_size_Z_M).unsqueeze(-1))).squeeze(-1) + self.m.expand(desired_size_Z_M)

    def backward(self, x):
        desired_size_X_M = list(x.shape)
        desired_size_X_M.insert(-1, self.K)
        desired_size_S = list(x.shape)
        desired_size_S.insert(-1, self.K)
        desired_size_S.insert(-1, self.p)
        return ((torch.inverse(self.chol).expand(desired_size_S))@((x.unsqueeze(-2).expand(desired_size_X_M)-self.m.expand(desired_size_X_M)).unsqueeze(-1))).squeeze(-1)

    def log_det_J(self):
        S = self.chol
        return torch.log(torch.diagonal(S,0,1,2)).sum(-1)


In [11]:
K = 5
dif = DIFDensityEstimator(target_samples, K)
dif.T  = FullRankLocationScaleFlow(K,p)
dif.T.m = torch.nn.Parameter(torch.mean(target_samples).unsqueeze(0).repeat(K,1) + 0.001*torch.rand(K,p))
#dif.T.chol = torch.cholesky(torch.cov(target_samples.T)).unsqueeze(0).repeat(K,1,1)
print(dif.T.chol.shape)
dif.train(200, 5000)

  0%|                                                                                                                                                             | 0/200 [00:00<?, ?it/s]

torch.Size([5, 784, 784])





RuntimeError: [enforce fail at ..\c10\core\CPUAllocator.cpp:76] data. DefaultCPUAllocator: not enough memory: you tried to allocate 61465600000 bytes.

In [None]:
dif.train(2000, 5000)

In [None]:
### Visualize training ###
model_to_visualize = dif

import numpy as np
from matplotlib.ticker import MaxNLocator
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes, mark_inset

loss_values = dif.loss_values
best_loss = min(loss_values)
best_iteration = loss_values.index(best_loss)
fig = plt.figure(figsize=(12, 4))
ax = plt.subplot(111)
Y1, Y2 = best_loss - (max(loss_values) -best_loss) / 2, max(loss_values) + (max(loss_values) - best_loss) / 4
ax.set_ylim(Y1, Y2)
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.plot(loss_values, label='Loss values during training', color='black')
ax.scatter([best_iteration], [best_loss], color='black', marker='d')
ax.axvline(x=best_iteration, ymax=(best_loss -best_loss + (max(loss_values) - best_loss) / 2) / (
        max(loss_values) + (max(loss_values) - best_loss) / 4 - best_loss + (
        max(loss_values) - best_loss) / 2), color='black', linestyle='--')
ax.text(0, best_loss - (max(loss_values) - best_loss) / 8,
        'best iteration = ' + str(best_iteration) + '\nbest loss = ' + str(np.round(best_loss, 5)),
        verticalalignment='top', horizontalalignment='left', fontsize=12)
if len(loss_values) > 30:
    x1, x2 = best_iteration - int(len(loss_values) / 15), min(best_iteration + int(len(loss_values) / 15),
                                                              len(loss_values) - 1)
    k = len(loss_values) / (2.5 * (x2 - x1 + 1))
    offset = (Y2 - Y1) / (6 * k)
    y1, y2 = best_loss - offset, best_loss + offset
    axins = zoomed_inset_axes(ax, k, loc='upper right')
    axins.axvline(x=best_iteration, ymax=(best_loss - y1) / (y2 - y1), color='black', linestyle='--')
    axins.scatter([best_iteration], [best_loss], color='black', marker='d')
    axins.xaxis.set_major_locator(MaxNLocator(integer=True))
    axins.plot(loss_values, color='black')
    axins.set_xlim(x1 - .5, x2 + .5)
    axins.set_ylim(y1, y2)
    mark_inset(ax, axins, loc1=3, loc2=4)

In [None]:
with torch.no_grad():
    for _ in range(50):
        plt.figure()
        sample = dif.sample_model(1)
        plt.imshow(sample[0].reshape(28,28))