In [5]:
import jax
import jax.numpy as jnp
from jax import random
import numpy as np
from flax import linen as nn
from typing import Callable, Any, Optional


In [139]:

class Encoder(nn.Module):
    """Convolutional encoder producing two latent heads, ``mean_x`` and ``logvar_x``.

    The heads are presumably the mean and log-variance of a diagonal Gaussian
    (VAE-style); the names suggest so, but this module does not enforce it.

    Architecture: eight Conv -> ReLU -> BatchNorm stages. Every conv uses
    ``kernel_dilation=2`` and 'same' padding. There is a stride-2 conv in the
    second stage and a 2x2 / stride-2 max-pool after the fifth stage, so height
    and width shrink by roughly 4x (exactly 4x for the notebook's 48 x 1876
    input). The last stage has a single channel. Its output is flattened and
    fed to two Dense heads named ``fc3_mean`` and ``fc3_logvar``.

    For the (16, 48, 1876, 1) input used in this notebook the pre-flatten
    bottleneck is (16, 12, 469, 1), i.e. 5628 features per example. The
    Decoder's initial reshape must agree with this.

    Attributes:
        latent_dim: width of both Dense heads (size of the latent space).
    """

    latent_dim: int = 20

    @nn.compact
    def __call__(self, x, train: bool = False):
        """Encode a channels-last batch ``x``.

        Args:
            x: input array, presumably (batch, H, W, C); the notebook uses
                (16, 48, 1876, 1). Flax convolutions are channels-last.
            train: if True, BatchNorm normalises with the current batch
                statistics and updates its running averages. In that case call
                ``apply`` with ``mutable=['batch_stats']``. The default, False,
                always uses the stored running averages, matching the module's
                original behaviour.

        Returns:
            ``(mean_x, logvar_x)``, each of shape (batch, latent_dim).
        """
        # Stage table: (features, kernel_size, strides, max-pool after stage?).
        # Modules are created in the same order as an unrolled version, so
        # Flax's auto-generated names (Conv_0..7, BatchNorm_0..7) are stable.
        stages = (
            (512, (1, 1), (1, 1), False),
            (512, (1, 1), (2, 2), False),
            (256, (1, 1), (1, 1), False),
            (128, (1, 3), (1, 1), False),
            (64, (1, 3), (1, 1), True),
            (32, (1, 1), (1, 1), False),
            (16, (1, 3), (1, 1), False),
            (1, (1, 3), (1, 1), False),
        )
        for features, kernel_size, strides, pool_after in stages:
            x = nn.Conv(features, kernel_size=kernel_size, strides=strides,
                        kernel_dilation=2, padding='same')(x)
            x = nn.relu(x)
            # use_running_average=True never uses or accumulates batch
            # statistics, so it must be False while training.
            x = nn.normalization.BatchNorm(use_running_average=not train)(x)
            if pool_after:
                x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

        # Flatten everything except the batch axis for the Dense heads.
        x = x.reshape(x.shape[0], -1)

        mean_x = nn.Dense(self.latent_dim, name='fc3_mean')(x)
        logvar_x = nn.Dense(self.latent_dim, name='fc3_logvar')(x)
        return mean_x, logvar_x

In [140]:
enc = Encoder()

In [141]:
rng = jax.random.PRNGKey(303)

In [142]:
x = np.random.randn(16, 48, 1876,1)

In [143]:
init_vari = enc.init(rng, x)

In [144]:
init_vari = enc.init(rng, x)
nn.tabulate(enc, rngs={'params': rng})(x)

'\n\n'

In [145]:
x = np.random.randn(16,20)

In [170]:

class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent vectors back to image space.

    A Dense layer projects each latent vector to ``H * W * C`` units, where
    ``(H, W, C) = bottleneck_shape``. The result is reshaped to
    (batch, H, W, C) and passed through five ConvTranspose stages, each with
    kernel (1, 3), kernel_dilation (2, 1) and the default 'SAME' padding, and
    each followed by leaky ReLU. The third and fourth stages use stride 2, so
    height and width grow by 4x. With the default (12, 469, 1) bottleneck the
    output is (batch, 48, 1876, 1), matching the Encoder input used in this
    notebook. Only the first three stages are followed by BatchNorm.

    Attributes:
        bottleneck_shape: (height, width, channels) the Dense projection is
            reshaped to. The default matches the Encoder's bottleneck for
            48 x 1876 inputs.
    """

    bottleneck_shape: tuple = (12, 469, 1)

    @nn.compact
    def __call__(self, x, train: bool = False):
        """Decode a latent batch ``x``, presumably of shape (batch, latent_dim).

        Args:
            x: latent vectors; the notebook uses shape (16, 20).
            train: if True, BatchNorm normalises with the current batch
                statistics and updates its running averages. In that case call
                ``apply`` with ``mutable=['batch_stats']``. The default, False,
                always uses the stored running averages, matching the module's
                original behaviour.

        Returns:
            Array of shape (batch, 4 * H, 4 * W, 1) for
            ``(H, W, C) = bottleneck_shape``.
        """
        height, width, channels = self.bottleneck_shape
        x = nn.Dense(height * width * channels)(x)
        x = x.reshape(x.shape[0], height, width, channels)

        # Stage table: (features, strides, BatchNorm after stage?).
        # Creation order matches an unrolled version, so Flax's auto-generated
        # names (ConvTranspose_0..4, BatchNorm_0..2) are unchanged.
        stages = (
            (32, (1, 1), True),
            (64, (1, 1), True),
            (128, (2, 2), True),
            (256, (2, 2), False),
            (1, (1, 1), False),
        )
        for features, strides, use_norm in stages:
            x = nn.ConvTranspose(features, kernel_size=(1, 3), strides=strides,
                                 kernel_dilation=(2, 1))(x)
            x = jax.nn.leaky_relu(x)
            if use_norm:
                # use_running_average=True never uses or accumulates batch
                # statistics, so it must be False while training.
                x = nn.normalization.BatchNorm(use_running_average=not train)(x)

        return x

In [171]:
dec = Decoder()

In [172]:
initial_vari = dec.init(rng, x = x)
nn.tabulate(dec, rngs={'params': rng})(x)

'\n\n'