In [120]:
import jax
import jax.numpy as jnp
from jax import random

from flax import linen as nn
from typing import Callable, Any, Optional


In [121]:

class Encoder(nn.Module):
    """Convolutional encoder returning latent mean / log-variance heads.

    Eight Conv -> ReLU -> BatchNorm stages (the 2nd and 5th followed by a
    2x2 max-pool) shrink the spatial grid; the result is flattened per
    example and fed to two parallel Dense(20) heads, `fc3_mean` and
    `fc3_logvar`.

    NOTE(review): the flattened feature size, and therefore the Dense kernel
    shapes, depends on the input height/width seen at `init` time.
    """

    @nn.compact
    def __call__(self, x, train: bool = False):
        """Encode `x` into `(mean_x, logvar_x)`.

        Args:
            x: input batch; the notebook feeds shape (16, 48, 1876, 1),
                presumably (batch, height, width, channels) -- TODO confirm.
            train: when True, BatchNorm normalises with the current batch
                statistics and updates its running averages (Flax then needs
                `mutable=['batch_stats']` in `apply`). The default False keeps
                the previous behaviour of always using the stored averages.

        Returns:
            `(mean_x, logvar_x)`, each of shape (batch, 20).
        """
        # Previously every BatchNorm was hard-wired to use_running_average=True,
        # so it could never use batch statistics nor update its running stats.
        use_running_average = not train

        # (features, kernel_size, strides, max-pool after the stage), listed in
        # creation order so auto-generated names stay Conv_0.. / BatchNorm_0..
        stages = (
            (512, (3, 3), (2, 2), False),
            (512, (3, 3), (2, 2), True),
            (256, (3, 3), (2, 2), False),
            (128, (3, 3), (1, 1), False),
            (64, (3, 3), (2, 2), True),
            (32, (3, 3), (1, 1), False),
            (16, (2, 2), (1, 1), False),
            (1, (2, 2), (1, 1), False),
        )
        for features, kernel_size, strides, pool_after in stages:
            x = nn.Conv(features, kernel_size=kernel_size, strides=strides, padding='same')(x)
            x = nn.relu(x)
            x = nn.BatchNorm(use_running_average=use_running_average)(x)
            if pool_after:
                x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

        # Flatten everything but the batch axis before the dense heads.
        x = x.reshape(x.shape[0], -1)

        mean_x = nn.Dense(20, name='fc3_mean')(x)
        logvar_x = nn.Dense(20, name='fc3_logvar')(x)
        return mean_x, logvar_x

In [122]:
# Instantiate the Encoder defined above; its variables are created by `enc.init` below.
enc = Encoder()

In [123]:
# Fixed PRNG key shared by the `init` / `tabulate` calls in this notebook.
rng = random.PRNGKey(303)

In [124]:
import numpy as np

In [213]:
# Dummy encoder input drawn from a seeded generator so Restart & Run All is
# reproducible (np.random.randn used the unseeded global RNG).
# Shape presumably (batch, height, width, channels) -- TODO confirm vs real data.
x = np.random.default_rng(303).standard_normal((16, 48, 1876, 1))

In [214]:
# Initialise the encoder's variables for the dummy batch `x`; the key seeds 'params'.
init_vari = enc.init({'params': rng}, x)

In [215]:
# `init_vari` was already computed from the same key and input in the previous
# cell, so only the layer summary is produced here. `nn.tabulate` returns a
# string; print it so the table renders instead of showing its escaped repr.
print(nn.tabulate(enc, rngs={'params': rng})(x))

'\n\n'

In [216]:

class Encoder(nn.Module):
    """Normalisation-free encoder variant (redefines the earlier `Encoder`).

    Eight Conv + ReLU stages (the BatchNorm / max-pool layers of the earlier
    variant are omitted), a per-example flatten, then two Dense(20) heads
    named `fc3_mean` and `fc3_logvar`.
    """

    @nn.compact
    def __call__(self, x):
        """Return `(mean_x, logvar_x)`, each of shape (batch, 20)."""
        # (features, kernel_size, strides) per Conv stage, in creation order so
        # the auto-generated names stay Conv_0 ... Conv_7. `None` = default stride.
        conv_stages = [
            (512, (3, 3), (2, 2)),
            (512, (3, 3), (2, 2)),
            (256, (3, 3), (2, 2)),
            (128, (3, 3), None),
            (64, (3, 3), (2, 2)),
            (32, (3, 3), None),
            (16, (2, 2), (2, 2)),
            (8, (2, 2), (2, 2)),
        ]
        h = x
        for n_features, kernel, stride in conv_stages:
            conv = nn.Conv(n_features, kernel_size=kernel, strides=stride, padding='same')
            h = nn.relu(conv(h))

        flat = h.reshape(h.shape[0], -1)
        # Tuple elements evaluate left to right, so fc3_mean is still built first.
        return (
            nn.Dense(20, name='fc3_mean')(flat),
            nn.Dense(20, name='fc3_logvar')(flat),
        )

In [217]:
# Re-instantiate using the redefined (BatchNorm-free) Encoder class from the cell above.
enc = Encoder()

In [218]:
init_vari = enc.init(rng, x)
# `nn.tabulate` returns a string; print it so the table renders instead of its repr.
print(nn.tabulate(enc, rngs={'params': rng})(x))

'\n\n'

In [159]:
# Dummy latent batch for the decoder, from a seeded generator for reproducibility.
# NOTE: this rebinds `x`, replacing the encoder input defined earlier.
x = np.random.default_rng(303).standard_normal((16, 20))

In [202]:
class Decoder(nn.Module):
    """Transposed-convolution decoder from a latent batch to a 1-channel map.

    A Dense layer produces 2 * 59 values per example, reshaped into a
    (2, 59, 1) grid, which seven ConvTranspose + ReLU stages then upsample.

    Attributes:
        recon_shape: intended reconstruction width (default 1876).
            NOTE(review): it is never read in `__call__`, so it has no effect;
            the output spatial size comes purely from the layer stack and
            presumably does not equal the encoder input's (48, 1876) --
            confirm with `nn.tabulate` before adding a reconstruction loss.
    """

    recon_shape: int = 1876

    @nn.compact
    def __call__(self, x):
        """Decode `x` (the notebook passes shape (16, 20)) into a 1-channel map."""
        dense_out = nn.Dense(2 * 59)(x)
        grid = dense_out.reshape(dense_out.shape[0], 2, 59, 1)

        # Kernel (1, 3): this first stage mixes only along the last spatial axis.
        grid = nn.relu(nn.ConvTranspose(64, kernel_size=(1, 3), padding='valid')(grid))

        # (features, strides) for the remaining 3x3 stages; `None` = default stride.
        upsample_stages = (
            (128, (2, 2)),
            (128, (2, 2)),
            (256, (2, 2)),
            (256, (2, 2)),
            (512, None),
            (1, (2, 2)),
        )
        for n_features, stride in upsample_stages:
            deconv = nn.ConvTranspose(n_features, kernel_size=(3, 3), strides=stride)
            # NOTE(review): the ReLU after the final stage makes the output
            # non-negative, while the notebook's dummy data is standard normal.
            grid = nn.relu(deconv(grid))

        return grid

In [203]:
# Instantiate the Decoder with its default attributes (recon_shape=1876).
dec = Decoder()

In [204]:
# Inspect the decoder input shape before initialising.
np.shape(x)

(16, 20)

In [205]:
# Initialise the decoder's variables for the dummy latent batch `x`.
initial_vari = dec.init(rng, x)

In [206]:
# `nn.tabulate` returns a string; print it so the table renders instead of its repr.
print(nn.tabulate(dec, rngs={'params': rng})(x))

'\n\n'