In [None]:
!pip install dm-haiku
!pip install optax

In [None]:
import jax.numpy as jnp
import jax
import haiku as hk
import optax
from typing import NamedTuple, Optional
import skimage.transform
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import torch

In [None]:
# Dataset location and loader batch size, kept in one place so they are easy to tune.
DATA_ROOT = './mnist_data/'
BATCH_SIZE = 128

# ToTensor yields a float (C, H, W) tensor; Resize then brings each image to 64x64,
# the resolution the VAE's conv stack is designed around.
resize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((64, 64)),
])

# download=True is a no-op once the files exist, so each split can be built on its own
# instead of silently relying on the train split having triggered the download first.
train_dataset = datasets.MNIST(root=DATA_ROOT, train=True, transform=resize, download=True)
test_dataset = datasets.MNIST(root=DATA_ROOT, train=False, transform=resize, download=True)

# drop_last is left at its default (False): the final batch of an epoch may be smaller
# than BATCH_SIZE, so consumers must not assume a fixed batch dimension.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
class ConvVaeEncoder(hk.Module):
    '''
    Convolutional encoder producing the parameters of a diagonal Gaussian posterior.

    Four ReLU convolutions (kernel 4, stride 2, VALID padding) with 32, 64, 128 and 256
    output channels, followed by two linear heads for ``mu`` and ``logsigma``.
    For a 64x64 input the spatial size goes 64 -> 31 -> 14 -> 6 -> 2, so the flattened
    feature vector has 2 * 2 * 256 = 1024 entries.

    Args:
        latent_size: Dimensionality of the latent vector z (default 32).
        name: Optional Haiku module name.
    '''
    def __init__(self, latent_size: int = 32, name: Optional[str] = None):
        super().__init__(name=name)
        self.padding = 'VALID'
        self.fc_size = latent_size
        # Output channels of the successive stride-2 convolutions.
        self.conv_channels = (32, 64, 128, 256)

    def __call__(self, x):
        '''
        Args:
            x: Image batch in NHWC layout (Haiku's Conv2D default), presumably (B, 64, 64, 1).

        Returns:
            Tuple ``(mu, logsigma)``, each of shape (B, latent_size).
        '''
        h = x
        # Layers are created in the same order as before, so Haiku parameter names are unchanged.
        for channels in self.conv_channels:
            h = hk.Conv2D(channels, kernel_shape=4, stride=2, padding=self.padding)(h)
            h = jax.nn.relu(h)
        # Flattens on everything except the batch dimension.
        fc_in = hk.Flatten(preserve_dims=1)(h)

        mu = hk.Linear(self.fc_size)(fc_in)
        logsigma = hk.Linear(self.fc_size)(fc_in)
        return mu, logsigma

class ConvVaeDecoder(hk.Module):
    '''
    Transposed-convolution decoder mapping a latent vector back to an image.

    The latent is projected to 1024 features and reshaped to a 1x1x1024 map. Four
    stride-2, VALID-padded transposed convolutions then upsample it. For VALID padding
    the output size is ``in * stride + max(kernel - stride, 0)``, giving
    1 -> 5 -> 13 -> 30 -> 64. The last layer uses a sigmoid so pixels lie in (0, 1).
    '''
    def __init__(self):
        super().__init__()
        self.padding = 'VALID'

    def __call__(self, x):
        # (out_channels, kernel_size, activation), applied in this order.
        layer_specs = (
            (128, 5, jax.nn.relu),
            (64, 5, jax.nn.relu),
            (32, 6, jax.nn.relu),
            (1, 6, jax.nn.sigmoid),
        )
        out = hk.Linear(1024)(x).reshape(-1, 1, 1, 1024)
        for out_channels, kernel, activation in layer_specs:
            deconv = hk.Conv2DTranspose(out_channels, kernel_shape=kernel, stride=2, padding=self.padding)
            out = activation(deconv(out))
        return out

class ConvVAE(hk.Module):
    '''
    Vision model: a convolutional variational autoencoder.

    The encoder produces ``(mu, logsigma)``. A latent ``z = mu + exp(logsigma) * eps``
    with ``eps ~ N(0, I)`` is sampled (reparameterization trick) and then decoded.

    Randomness: when the transformed function is given an rng (``hk.transform`` with a
    key passed to ``apply``), fresh noise is drawn on every call. Without an rng
    (``hk.without_apply_rng`` or ``apply(params, None, ...)``), a fixed key is used.
    That fallback reproduces the previous deterministic behaviour.
    '''
    def __init__(self):
        super().__init__()
        self.encoder = ConvVaeEncoder()
        self.decoder = ConvVaeDecoder()

    def __call__(self, x):
        '''Returns ``(mu, logsigma, decoded)`` for the input image batch ``x``.'''
        mu, logsigma = self.encoder(x)
        sigma = jnp.exp(logsigma)
        # A hard-coded key would make eps identical on every step, defeating the sampling.
        key = hk.maybe_next_rng_key()
        if key is None:
            key = jax.random.PRNGKey(42)
        eps = jax.random.normal(key, sigma.shape)
        z = sigma * eps + mu
        decoded = self.decoder(z)
        return mu, logsigma, decoded

In [None]:
class TrainingState(NamedTuple):
    """Everything the training loop threads from one optimizer step to the next.

    As a NamedTuple it is a JAX pytree, so it can be passed into and returned from
    ``jax.jit``-compiled functions unchanged.
    """
    params: hk.Params          # current model parameters
    opt_state: optax.OptState  # optimizer state from optax's init/update

def main(num_epochs: int = 10, learning_rate: float = 3e-4, seed: int = 42):
    '''
    Train the ConvVAE on MNIST, then plot one test image next to its reconstruction.

    Reads the notebook-level ``train_loader`` and ``test_loader``, which yield
    ``(images, labels)`` with images as torch tensors in NCHW layout.

    Args:
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Adam step size.
        seed: Seed for parameter initialisation and the per-step sampling keys.
    '''

    def to_nhwc(batch):
        # torch batches are NCHW, while Haiku's Conv2D defaults to NHWC. Transpose rather
        # than reshape, and take the size from the batch itself: the final batch of an
        # epoch may be smaller than the others.
        return np.transpose(batch.numpy(), (0, 2, 3, 1))

    def forward(x):
        return ConvVAE()(x)

    # hk.transform (not without_apply_rng) so apply() can receive a key for eps sampling.
    model = hk.transform(forward)
    optimizer = optax.adam(learning_rate)

    def loss_fn(params, rng, inputs):
        '''
        Negative ELBO averaged over the batch, up to an additive constant.

        It is the per-image summed L2 reconstruction error plus
        KL(N(mu, sigma^2) || N(0, I)). Both terms are reduced per example before
        averaging, so their relative weight does not depend on batch size or image size.
        '''
        mu, logsigma, decoded = model.apply(params, rng, inputs)
        recon = jnp.sum(optax.l2_loss(decoded, inputs), axis=(1, 2, 3))
        kld = -0.5 * jnp.sum(1 + 2 * logsigma - jnp.square(mu) - jnp.exp(2 * logsigma), axis=-1)
        return jnp.mean(recon + kld), decoded

    @jax.jit
    def update_weights(state, rng, inputs):
        '''One Adam step; returns the new state and ``(loss, reconstruction)``.'''
        grad_fn = jax.value_and_grad(loss_fn, argnums=0, has_aux=True)
        (loss, image), grads = grad_fn(state.params, rng, inputs)
        updates, opt_state = optimizer.update(grads, state.opt_state)
        params = optax.apply_updates(state.params, updates)
        return TrainingState(params=params, opt_state=opt_state), (loss, image)

    rng = jax.random.PRNGKey(seed)
    rng, init_rng = jax.random.split(rng)
    sample_images, _ = next(iter(train_loader))
    initial_params = model.init(init_rng, to_nhwc(sample_images))
    state = TrainingState(initial_params, optimizer.init(initial_params))

    for epoch in range(num_epochs):
        total_loss, num_batches = 0.0, 0
        for images, _ in train_loader:
            rng, step_rng = jax.random.split(rng)
            state, (batch_loss, _) = update_weights(state, step_rng, to_nhwc(images))
            total_loss += float(batch_loss)
            num_batches += 1
        print(f'epoch {epoch + 1}/{num_epochs}: mean loss {total_loss / max(num_batches, 1):.4f}')

    # Reconstruct a held-out image; slicing with [:1] keeps the batch axis.
    test_images, _ = next(iter(test_loader))
    sample = to_nhwc(test_images)[:1]
    rng, eval_rng = jax.random.split(rng)
    _, _, decoded = model.apply(state.params, eval_rng, sample)

    fig, (ax_in, ax_out) = plt.subplots(1, 2)
    ax_in.imshow(np.squeeze(sample), cmap='gray')
    ax_in.set_title('input')
    ax_out.imshow(np.squeeze(np.asarray(decoded)), cmap='gray')
    ax_out.set_title('reconstruction')
    for ax in (ax_in, ax_out):
        ax.axis('off')
    plt.show()

In [None]:
# Run training with the default settings and show a reconstruction.
# In a notebook kernel __name__ is '__main__', so this always runs here.
if __name__ == '__main__':
    main()