In [4]:
from torch import Tensor

import FrEIA.framework as Ff
import FrEIA.modules as Fm

import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from sklearn.datasets import make_moons
from pathlib import Path
from mmvae_hub.utils.fusion_functions import mixture_component_selection_embedding

from torch.distributions import MultivariateNormal
from mmvae_hub.utils.dataclasses.Dataclasses import *
from torch.distributions.normal import Normal

In [5]:
BATCHSIZE = 1000
N_DIM = 2
P = Normal(torch.zeros(N_DIM), torch.ones(N_DIM))

In [6]:
def subnet_fc(dims_in, dims_out):
    return nn.Sequential(nn.Linear(dims_in, 512), nn.ReLU(),
                         nn.Linear(512, dims_out))


In [7]:
# a simple chain of operations is collected by ReversibleSequential
flow = Ff.SequenceINN(N_DIM)
for k in range(8):
    flow.append(Fm.AllInOneBlock, subnet_constructor=subnet_fc, permute_soft=True)

In [8]:
def reparameterize(distr, eps):
    """
    Samples z from a multivariate Gaussian with diagonal covariance matrix using the
     reparameterization trick.
    """

    std = torch.diag(distr.scale_tril).sqrt()
    return eps.mul(std).add_(distr.loc)

In [9]:
enc_mods = {k: Normal(torch.rand(N_DIM), torch.rand(N_DIM)) for k in [1, 2, 3]}




In [19]:
optimizer = torch.optim.Adam(flow.parameters(), lr=0.001)
losses = []
for i in range(1000):
    optimizer.zero_grad()

    samples = {k: enc_mod.sample((100,)) for k, enc_mod in enc_mods.items()}
    tf_enc_mods = {k:flow(sample) for k,sample in samples.items()}
    z_mean = torch.stack([tf_enc_mod[0] for _, tf_enc_mod in tf_enc_mods.items()]).mean(0)
    z, log_jac_det = flow(z_mean, rev=True)

    # interm loss
    interm_loss = torch.stack([enc_mods[k].log_prob(samples[k]) - tf_enc_mods[k][1] - P.log_prob(tf_enc_mods[k][0]) for k in enc_mods]).mean(0)

    Gf = Normal(torch.stack([enc_mod.loc for _,enc_mod in enc_mods.items()]).sum(),  (torch.ones(N_DIM)*3).sqrt())

    # calculate the negative log-likelihood of the model with a standard normal prior
    D_kl = Gf.log_prob(z_mean) + log_jac_det - P.log_prob(z)

    loss = interm_loss + D_kl
    loss = loss.mean() / N_DIM

    # backpropagate and update the weights
    loss.backward()
    optimizer.step()
    losses.append(loss)

    if i % 10 == 0:
        plt.figure()
        plt.subplot(1, 4, 1)
        p = P.sample((BATCHSIZE,))
        z_detached = z.detach().numpy()
        plt.title(f'step {i}')
        plt.scatter(z_detached[:, 0], z_detached[:, 1])
        plt.scatter(p[:, 0], p[:, 1], alpha=0.1)
        plt.axis('off')

        for k,tf_enc_mod in tf_enc_mods.items():


            plt.subplot(1, 4, k+1)
            plt.title(f'z{k}')
            z1_detached = tf_enc_mod[0].detach().numpy()
            plt.scatter(z1_detached[:, 0], z1_detached[:, 1])
            plt.axis('off')



        plt.show()
torch.save(flow.state_dict(), 'flow')

RuntimeError: The size of tensor a (2) must match the size of tensor b (100) at non-singleton dimension 1

In [None]:
plt.plot(losses)
plt.title('Loss')
plt.show()