# 変分オートエンコーダー（VAE）

ここでは JAX / Flax を用いて変分オートエンコーダー（VAE）を実装する。

## Imports

In [10]:
from pprint import pprint

import numpy as np

import jax
import flax
import optax
from jax import numpy as jnp
from jax import random
from flax import linen as nn

In [25]:
#@title Model Definition

class Encoder(nn.Module):
    """Encode an input into the parameters of a diagonal Gaussian.

    One ReLU hidden layer feeds two linear heads: a mean head and a
    log-variance head. The log-variance is turned into a standard deviation
    via ``sigma = exp(logvar / 2)``.

    Attributes:
        hidden_dim: Width of the shared hidden layer.
        latent_dim: Size of both the mean and the standard-deviation outputs.

    Returns (from ``__call__``):
        A tuple ``(mu, sigma)``.
    """
    hidden_dim: int = 512
    latent_dim: int = 32

    @nn.compact
    def __call__(self, x):
        hidden = nn.relu(nn.Dense(self.hidden_dim, name='Dense1')(x))
        mean_head = nn.Dense(self.latent_dim, name='Dense_mu')
        logvar_head = nn.Dense(self.latent_dim, name='Dense_logvar')
        mu = mean_head(hidden)
        log_variance = logvar_head(hidden)
        # Predicting log-variance and exponentiating keeps sigma positive
        # without constraining the raw Dense output.
        return mu, jnp.exp(log_variance * 0.5)


class Decoder(nn.Module):
    """Decode a latent vector through one ReLU hidden layer and a sigmoid.

    Attributes:
        hidden_dim: Width of the hidden layer.
        output_dim: Size of the decoded output (last axis).

    The final activation is an elementwise sigmoid, so outputs are squashed
    into the unit interval.
    """
    hidden_dim: int = 512
    output_dim: int = 32

    @nn.compact
    def __call__(self, x):
        # Applied in order: Dense1 -> ReLU -> Dense2 -> sigmoid.
        pipeline = (
            nn.Dense(self.hidden_dim, name='Dense1'),
            nn.relu,
            nn.Dense(self.output_dim, name='Dense2'),
            nn.sigmoid,
        )
        for stage in pipeline:
            x = stage(x)
        return x


def reparameterize(mu, sigma, key=None):
    """Sample ``z = mu + sigma * eps`` with ``eps ~ N(0, I)`` (reparameterization trick).

    Args:
        mu: Mean of the Gaussian.
        sigma: Standard deviation. ``eps`` is drawn with ``sigma.shape`` and
            broadcast against ``mu``.
        key: Optional ``jax.random`` PRNG key. When given, ``eps`` comes from
            ``jax.random.normal``. That makes sampling reproducible and
            jit-safe: each call under ``jax.jit`` with a new key gives new noise.
            When omitted, NumPy's global RNG is used as in the original code.
            That path is not controlled by JAX keys, and under ``jax.jit`` its
            noise is sampled once at trace time and then reused.

    Returns:
        ``mu + eps * sigma``.
    """
    if key is None:
        # Legacy path kept for backward compatibility with callers passing no key.
        eps = np.random.randn(*sigma.shape)
    else:
        eps = random.normal(key, sigma.shape)
    return mu + eps * sigma


class VAE(nn.Module):
    """Variational autoencoder: Encoder -> reparameterized sample -> Decoder.

    Attributes:
        hidden_dim: Hidden-layer width passed to both Encoder and Decoder.
        latent_dim: Size of the latent sample fed to the Decoder.
        output_dim: Size of the Decoder output. The default of 32 preserves the
            original behavior. NOTE(review): for a reconstruction model this
            should presumably equal the input feature size (``x.shape[-1]``);
            the demo feeds 10 features but gets 32 outputs.

    Sampling noise:
        If the module is run with a ``'noise'`` RNG stream, e.g.
        ``model.apply(params, x, rngs={'noise': key})``, eps is drawn from that
        Flax-managed stream. This makes it reproducible and jit-safe.
        Without it, e.g. ``model.init(key, x)`` or ``model.apply(params, x)``,
        the original NumPy-based ``reparameterize(mu, sigma)`` is used.

    Returns (from ``__call__``):
        Only the decoded output; ``mu``/``sigma`` (needed for a KL term) are
        not returned.
    """
    hidden_dim: int = 512
    latent_dim: int = 32
    output_dim: int = 32

    @nn.compact
    def __call__(self, x):
        mu, sigma = Encoder(hidden_dim=self.hidden_dim, latent_dim=self.latent_dim)(x)
        if self.has_rng('noise'):
            eps = random.normal(self.make_rng('noise'), sigma.shape)
            z = mu + eps * sigma
        else:
            # No 'noise' stream supplied: keep the original sampling behavior.
            z = reparameterize(mu, sigma)
        return Decoder(hidden_dim=self.hidden_dim, output_dim=self.output_dim)(z)

In [26]:
model = VAE()

# One key for the dummy input, one for parameter initialization.
key1, key2 = random.split(random.key(0))
x = random.normal(key1, (10,))  # dummy input vector with 10 features
params = model.init(key2, x)
param_shapes = jax.tree_util.tree_map(jnp.shape, params)
pprint(param_shapes)  # inspect the shape of every parameter

{'params': {'Decoder_0': {'Dense1': {'bias': (512,), 'kernel': (32, 512)},
                          'Dense2': {'bias': (32,), 'kernel': (512, 32)}},
            'Encoder_0': {'Dense1': {'bias': (512,), 'kernel': (10, 512)},
                          'Dense_logvar': {'bias': (32,), 'kernel': (512, 32)},
                          'Dense_mu': {'bias': (32,), 'kernel': (512, 32)}}}}


In [27]:
# Forward pass with the initialized parameters; display the output shape(s).
output = model.apply(params, x)
output_shapes = jax.tree_util.tree_map(jnp.shape, output)
output_shapes

(32,)