<a href="https://colab.research.google.com/github/MaxW-M/Fragment-VAE/blob/main/VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#@title Installs

!pip install rdkit
!tar -xzvf openbabel-3-1-0.tar.gz
!cd openbabel-openbabel-3-1-0
!python setup.py install

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting rdkit
  Downloading rdkit-2023.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (29.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.7/29.7 MB[0m [31m35.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: rdkit
Successfully installed rdkit-2023.3.2
tar (child): openbabel-3-1-0.tar.gz: Cannot open: No such file or directory
tar (child): Error is not recoverable: exiting now
tar: Child returned status 2
tar: Error is not recoverable: exiting now
/bin/bash: line 0: cd: openbabel-openbabel-3-1-0: No such file or directory
python3: can't open file '/content/setup.py': [Errno 2] No such file or directory


In [2]:
# @title imports

from keras.callbacks import *
import pandas as pd
import argparse
import os
import numpy as np
from keras.layers import Input, Conv1D, Embedding, Flatten, Dense, Reshape, Lambda, Activation, Dropout, BatchNormalization
from keras.models import Model
from keras import backend as K
from keras.optimizers import Adam
from keras.metrics import categorical_crossentropy
from os.path import *
import multiprocessing as mp
import re
import glob
#from openbabel import pybel
import ast
from rdkit import Chem
from functools import partial
import keras
import warnings
import unittest
from numpy.testing import assert_array_equal
from tensorflow.python.framework.ops import disable_eager_execution

disable_eager_execution()

In [3]:
# @title Preprocess

# the only characters that are allowed to be in a SMILES string for our purposes
SMI = ['PAD',
       '1', '2', '3', '4', '5', '6', '7', '8', '9', '0',
       'c', 'h', 'C', 'H', 'N', 'O', 'P', 'S', 'F', 'C', 'I', 'B',
       '/', '-', '(', ')', ',',
       'n', 'o', 'p', 's','f', 'c', 'l', 'b', 'r', '.', '=',
       '#', '$', ':', '\\', '@', '+', '[', ']'  # add'l smi chars
       ]


'''
preconditions: struct is a SMILES string
               charset is a list of characters
               max_length is the maximum size string allowed for enconding
postconditions: returns struct as a vector (an array of ints) or an array of zeros
                if output is None
'''
def struct2vec(struct, charset=SMI, max_length=100):
    if pd.isna(struct):
      vec = np.array([])
    else:
      vec = np.zeros(len(struct))
      for i in range(len(struct)):
        s = struct[i]
        if s in charset:
            vec[i] = charset.index(s)
        else:  # Illegal character in s
            return np.zeros(max_length)

    if len(vec) > max_length:
      return np.zeros(max_length)
    else:
      vec = np.append(vec, np.zeros(max_length - vec.size))
      return np.array(vec, dtype=np.uint8)


'''
preconditions: smiles is an array of SMILES strings
postcondition: converts the SMILES strings to vectors and returns them as an array
'''
def vectorize(smiles, processes=mp.cpu_count()):
    with mp.Pool(processes=processes) as p:
        return p.map(struct2vec, smiles)


'''
preconditions: df is a dataframe with a canonical SMILES column
               name is a string
               output is an address to an output file
               shuffle is a boolean
postconditions: adds a vec column to df consisting of its vectorized SMILES column
'''
def process(df, name, output, shuffle=False):
    # shuffle data
    if shuffle:
        df = df.sample(frac=1).reset_index(drop=True)

    # already converted
    if 'SMILES' in df.columns:
        pass
    # error
    else:
        raise KeyError('Dataframe must have a "SMILES" column.')

    # turn SMILES strings into vectors
    vectors = np.vstack(vectorize(df['SMILES'].values))

    # set vectors consisting of only 0s to NaN
    vectors = np.where(np.all(vectors == 0, axis=1, keepdims=True), np.nan, vectors)

    # add a vec column to df
    df['vec'] = vectors.tolist()

    # drop any row that has a NaN value in one of its columns
    df.dropna(how='any', axis=0, inplace=True)

    arr = np.vstack(df['vec'].values)
    labels = df.drop(columns=['SMILES', 'vec'])

    # save
    np.save(os.path.join(output, '%s.npy' % name), arr)

    if len(labels.columns) > 0:
        np.save(os.path.join(output, '%s_labels.npy' % name), labels.values)




In [4]:
# @title Fragment Preprocessing

input_data = ['SMILES', '1-octanol', 'butyl acetate', 'chloroform', 'cyclohexane',
              'dichloromethane', 'triolein', 'n-hexane', 'n-octane', 'oleylalcohol',
              'toluene', 'n-undecane', 'MONOISOTOPIC_MASS', 'H', 'C', 'N', 'O',
              'F', 'P', 'S', 'Cl', 'I', 'Br']

'''

'i_Al_COO','i_Al_OH','i_Al_OH_noTert','i_ArN',
              'i_Ar_COO','i_Ar_N','i_Ar_NH','i_Ar_OH','i_COO','i_COO2','i_C_O','i_C_O_noCOO','i_C_S',
              'i_HOCCN','i_Imine','i_NH0','i_NH1','i_NH2','i_N_O','i_Ndealkylation1','i_Ndealkylation2',
              'i_Nhpyrrole','i_SH','i_aldehyde','i_alkyl_carbamate','i_alkyl_halide','i_allylic_oxid',
              'i_amide','i_amidine','i_aniline','i_aryl_methyl','i_azide','i_azo','i_barbitur','i_benzene',
              'i_benzodiazepine','i_bicyclic','i_diazo','i_dihydropyridine','i_epoxide','i_ester','i_ether',
              'i_furan','i_guanido','i_halogen','i_hdrzine','i_hdrzone','i_imidazole','i_imide','i_isocyan',
              'i_isothiocyan','i_ketone','i_ketone_Topliss','i_lactam','i_lactone','i_methoxy','i_morpholine',
              'i_nitrile','i_nitro','i_nitro_arom','i_nitro_arom_nonortho','i_nitroso','i_oxazole','i_oxime',
              'i_para_hydroxylation','i_phenol','i_phenol_noOrthoHbond','i_phos_acid','i_phos_ester','i_piperdine',
              'i_piperzine','i_priamide','i_prisulfonamd','i_pyridine','i_quatN','i_sulfide','i_sulfonamd','i_sulfone',
              'i_term_acetylene','i_tetrazole','i_thiazole','i_thiocyan','i_thiophene','i_unbrch_alkane','i_urea'


preconditions: data is the path to the data file
               output is the path to the file where the trimmed csv will be saved
               set shuffle to True if you want the data shuffled
postconditions: returns input and output data as x and y respectively
'''
def process_fragments(data, output, shuffle=False):
    # read csv file with the data
    df = pd.read_csv(data)

    # remove any non-input or SMILES columns
    for col in df.columns:
      if col not in input_data:
        df.drop(columns=col, inplace=True)

    # shuffle the data
    if shuffle:
      df = df.sample(frac=1).reset_index(drop=True)

    # save the trimmed data frame
    df.to_csv(output, index=False)

    # split the dataframe into x and y values
    y_df = df
    y_df = y_df[['SMILES']]
    df.drop(columns='SMILES', inplace=True)

    # turn x into a numpy array
    x = df.to_numpy()

    return x, y_df

In [49]:
from keras.activations import linear
# @title Fragment VAE

class fragment_VAE():

    def create(self, max_length, nchars, inputdim, **kwargs):
      self.max_length = max_length
      self.nchars = nchars
      self.inputdim = inputdim

      self.create_encoder()
      self.create_decoder()
      self.autoencoder = Model(inputs=self.x,
                                 outputs=self.decoder(self.encoder(self.x)),
                                 name='vae')

      #self.decoder.trainable=False
      #self.decoder.trainable = False
      opt = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=1E-8, amsgrad=True)
      self.autoencoder.compile(optimizer=opt,
                                 #loss='mean_squared_error',
                                 loss=categorical_crossentropy,
                                 metrics=['accuracy', categorical_crossentropy, ignore_accuracy])

    def create_encoder(self):
      self.x = Input(shape=self.inputdim)
      self.encoder = keras.Sequential([Input(shape=self.inputdim),
                                       BatchNormalization(),
                                       Dropout(0.2),
                                       #Dense(292, activation='relu'),
                                       #Dropout(0.2),
                                       Dense(292, activation='linear')],
                                      name='encoder')

    def create_decoder(self):
        # define latent input
        encoded_input = Input(shape=(292,))

        # connect to latent dim
        h = Reshape((292, 1))(encoded_input)

        # build filters
        for f, k in zip([10, 10, 11], [9, 9, 10]):
            h = Conv1D(f, k, activation='relu', padding='same')(h)

        # prepare output dim
        h = Flatten()(h)
        h = Dense(self.max_length * self.nchars)(h)
        h = Reshape((self.max_length, self.nchars))(h)
        decoded = Activation('softmax')(h)

        # construct decoder (latent->output)
        self.decoder = Model(inputs=encoded_input,
                             outputs=decoded,
                             name='decoder')



'''
a custom accuracy metric that ignores padded characters
'''

def ignore_accuracy(y_true, y_pred):
        y_true_class = K.argmax(y_true, axis=-1)
        y_pred_class = K.argmax(y_pred, axis=-1)

        ignore_mask = K.cast(K.not_equal(y_pred_class, 0), 'int32')
        matches = K.cast(K.equal(y_true_class, y_pred_class), 'int32') * ignore_mask
        accuracy = K.sum(matches) / K.maximum(K.sum(ignore_mask), 1)
        return accuracy

In [68]:
# @title Fragment Training

'''
preconditions: data is the path to the input data
               output is the path to the output folder
               validation is a float that represents the percentage of SMILES strings to be withheld for validation
               kernels is a list of ints, the kernel size for each convolution
               filters is a list of ints, the number of filters per convolution, must be the same size as kernels
               embedding_dim is the input vector of the embedding dimension
               latent_dim is the dimension of the latent encoding space
               epsilon is the standard deviation of the latent space
               freeze_vae freezes the autoencoder weights
               weights is the path to a directory containing pretrained weights
               patience is the early stopping patience, an int
               batch size is the size of training batches, an int
               epochs is the number of training epochs

postconditions: creates, trains, and saves the model
'''
def train_fragments(data, output, inputdim=22, validation=0.1, kernels=[9, 9, 10], filters=[10, 10, 11], embedding_dim=32, latent_dim=292,
          epsilon=0.1, freeze_vae=False, weights=None, patience=5, batch_size=128, epochs=10):
    # load data
    x, y_df = process_fragments(data, 'trimmed_data.csv')

    process(y_df, 'fragment_y', 'processed_data')
    y = np.vstack(y_df['vec'].values).astype(np.uint8)
    n, m = y.shape

    d = max(np.unique(y)) + 1

    # test/train split
    mask = test_train_split(x, validation)
    x_train = x[mask]
    x_validation = x[~mask]
    y_train = y[mask]
    y_validation = y[~mask]


    max_length = m
    nchars = d

    args = {'which':'train',
            'nchars':d,
            'max_length':m,
            'inputdim':inputdim,
            'data':data,
            'output':output,
            'labels':None,
            'kernels':kernels,
            'filters':filters,
            'embedding_dim':embedding_dim,
            'latent_dim':latent_dim,
            'epsilon':epsilon,
            'dropout':0.2,
            'weights':None,
            'freeze_vae':False,
            'validation':0.1,
            'batch_size':batch_size,
            'epochs':epochs,
            'patience':patience}

    savedict(args, output, verbose=False)

    #TODO: remove
    print(args)

    # one-hot encode targets
    y_validation = keras.utils.to_categorical(y_validation, d)
    y_validation = y_validation.reshape((-1, m, d))

    y_train = keras.utils.to_categorical(y_train, d)
    y_train = y_train.reshape((-1, m, d))

    # initialize autoencoder
    model = fragment_VAE()
    model.create(max_length, nchars, inputdim)

    # model checkpointing
    models = [model.autoencoder, model.encoder, model.decoder]
    filepaths = [os.path.join(output, f) for f in ('vae.h5',
                                                        'encoder.h5',
                                                        'decoder.h5',)]
    checkpoint = MultiModelCheckpoint(models, filepaths, monitor='val_loss',
                                      save_best_only=True, mode='min', save_weights_only=True)

    # print model summary
    print(model.encoder.summary())
    print(model.decoder.summary())
    print(model.autoencoder.summary())


    # optionally load weights
    if weights is not None:
        model.encoder_variational.load_weights(os.path.join(weights, 'encoder+v.h5'))
        model.decoder.load_weights(os.path.join(weights, 'decoder.h5'))

    # early stopping
    early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', patience=patience, mode='min')

    # history
    history = LossHistory(os.path.join(output, 'loss_history.tsv'))

    # train vae
    model.autoencoder.fit(x_train, y_train,
                          batch_size=batch_size,
                          epochs=epochs,
                          validation_data=(x_validation, y_validation),
                          callbacks=[early_stop, checkpoint, history],
                          shuffle=True,
                          verbose=2)


'''
preconditions: x is a numpy array of data
               test_size is the percentage of data to be witheld for testing
postconditions: returns a random list of ones and zeros that indicate which data points
                should be used for testing and which data points should be used for training
'''
def test_train_split(x, test_size=0.1):

    idx = np.random.choice(np.arange(len(x)), size=int(len(x) * test_size), replace=False)
    mask = np.ones(len(x)).astype('bool')
    mask[idx] = False

    return mask

'''
preconditions: d is a dictionary
               path is a path to an output file
               verbose is a boolean
postconditions: saves d to a file, used to save the variables used in the
                creation of a model
'''
def savedict(d, path, verbose=True):
    if verbose:
        print('Arguments:')
    with open(os.path.join(path, 'arguments.txt'), 'w') as f:
        for k, v in d.items():
            if v is None:
                f.write("%s: %s\n" % (k, '-1'))
            else:
                f.write("%s: %s\n" % (k, v))
            if verbose:
                print("\t%s: %s" % (k, v))


class MultiModelCheckpoint(Callback):
    """Save the model after every epoch.
    `filepath` can contain named formatting options,
    which will be filled the value of `epoch` and
    keys in `logs` (passed in `on_epoch_end`).
    For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`,
    then the model checkpoints will be saved with the epoch number and
    the validation loss in the filename.
    # Arguments
        filepath: string, path to save the model file.
        monitor: quantity to monitor.
        verbose: verbosity mode, 0 or 1.
        save_best_only: if `save_best_only=True`,
            the latest best model according to
            the quantity monitored will not be overwritten.
        mode: one of {auto, min, max}.
            If `save_best_only=True`, the decision
            to overwrite the current save file is made
            based on either the maximization or the
            minimization of the monitored quantity. For `val_acc`,
            this should be `max`, for `val_loss` this should
            be `min`, etc. In `auto` mode, the direction is
            automatically inferred from the name of the monitored quantity.
        save_weights_only: if True, then only the model's weights will be
            saved (`model.save_weights(filepath)`), else the full model
            is saved (`model.save(filepath)`).
        period: Interval (number of epochs) between checkpoints.
    """

    def __init__(self, models, filepaths, monitor='val_loss',
                 save_best_only=False, save_weights_only=False,
                 mode='auto', period=1):
        super(MultiModelCheckpoint, self).__init__()
        self.models = models
        self.monitor = monitor
        self.filepaths = filepaths
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.period = period
        self.epochs_since_last_save = 0

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('ModelCheckpoint mode %s is unknown, '
                          'fallback to auto mode.' % (mode),
                          RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
            self.best = np.Inf
        elif mode == 'max':
            self.monitor_op = np.greater
            self.best = -np.Inf
        else:
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
                self.monitor_op = np.greater
                self.best = -np.Inf
            else:
                self.monitor_op = np.less
                self.best = np.Inf

    '''
    preconditions: epoch is the current training epoch
                   logs is a log of the monitored values
    postconditions: at the end of each epoch if the epoch is a save epoch
                    then save the model
    '''
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save >= self.period:
            self.epochs_since_last_save = 0
            if self.save_best_only:
                current = logs.get(self.monitor)
                if current is None:
                    warnings.warn('Can save best model only with %s available, '
                                  'skipping.' % (self.monitor), RuntimeWarning)
                else:
                    if self.monitor_op(current, self.best):
                        self.best = current
                        for m, p in zip(self.models, self.filepaths):
                            if self.save_weights_only:
                                m.save_weights(p, overwrite=True)
                            else:
                                m.save(p, overwrite=True)
            else:
                for m, p in zip(self.models, self.filepaths):
                    if self.save_weights_only:
                        m.save_weights(p, overwrite=True)
                    else:
                        m.save(p, overwrite=True)

'''
Saves the loss at the end of each epoch
'''
class LossHistory(Callback):
    def __init__(self, path):
        self.path = path

    def on_train_begin(self, logs={}):
        self.losses = []

    def on_batch_end(self, batch, logs={}):
        self.losses.append(logs)

    def on_epoch_end(self, epoch, logs=None):
        self.save()

    def save(self):
        pd.DataFrame(self.losses).to_csv(self.path, index=False, sep='\t')

In [75]:
# @title Predict SMILES from Fragments
def predict(data, network):
    df = pd.read_csv(data)

    name = 'Predicted SMILES'

    # split the dataframe into x and y values
    df.drop(columns='SMILES', inplace=True)

    # load model
    config = load_config(join(network, 'arguments.txt'))
    config['output'] = network
    model = fragment_model_from_config(config)

    print(model.encoder.summary())
    print(model.decoder.summary())
    print(model.autoencoder.summary())


    # predict latent
    latent = model.encoder.predict(df)

    # softmax
    softmax = model.decoder.predict(latent)

    # argmax and convert to smiles
    smiles_out = np.array([vec2struct(x) for x in np.argmax(softmax, axis=-1)])

    return smiles_out

def fragment_model_from_config(config):
    #TODO: Remove
    print(config)

    # initialize autoencoder
    model = fragment_VAE()
    model.create(**config)

    # load weights
    model.encoder.load_weights(os.path.join(config['output'], 'encoder.h5'))
    if os.path.exists(os.path.join(config['output'], 'decoder.h5')):
        model.decoder.load_weights(os.path.join(config['output'], 'decoder.h5'))

    return model

def load_config(filepath):
    config = {}
    with open(filepath) as f:
        for line in f:
            #TODO: remove
            print(line)

            (key, val) = [x.strip() for x in line.split(':')]
            if val == '-1':
                config[key] = None
            elif key in ['nchars', 'max_length', 'embedding_dim', 'nlabels','inputdim',
                         'latent_dim', 'batch_size', 'epochs', 'patience', 'seed']:
                config[key] = int(val)
            elif key in ['kernels', 'filters', 'freeze_vae']:
                config[key] = ast.literal_eval(val)
            elif key in ['epsilon', 'dropout', 'validation']:
                config[key] = float(val)
            elif key in ['data', 'output', 'weights', 'labels']:
                config[key] = val

    config['epsilon_std'] = config.pop('epsilon')
    config['output'] = dirname(filepath)
    return config

def vec2struct(vec, charset=SMI):
    '''
    Decodes a structure using the given charset.

    Parameters
    -------
    vec : unit8 array
        Encoded structure

    charset : list, optional
        Character set used for encoding.

    Returns
    -------
    struct : str
        Structure of compound, represented as an InChI or SMILES string.
        Note: for InChIs, no layer past the hydrogen layer will be
        available. All conformers are possible.
    '''

    # Init
    struct = ''
    for i in vec:
        try:
            # Place character
            if charset[i] != 'PAD':
                struct += charset[i]
        except:
            #raise KeyError('Invalid character encountered.')
            return None

    # Return decoded structure.
    return struct


In [77]:
# @title Fragment VAE Training
train_fragments('drive/MyDrive/Data Sets/TurboChemDB5.2.csv', 'fragment_model', epochs=100)

{'which': 'train', 'nchars': 47, 'max_length': 100, 'inputdim': 22, 'data': 'drive/MyDrive/Data Sets/TurboChemDB5.2.csv', 'output': 'fragment_model', 'labels': None, 'kernels': [9, 9, 10], 'filters': [10, 10, 11], 'embedding_dim': 32, 'latent_dim': 292, 'epsilon': 0.1, 'dropout': 0.2, 'weights': None, 'freeze_vae': False, 'validation': 0.1, 'batch_size': 128, 'epochs': 100, 'patience': 5}
Model: "encoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   


  super().__init__(name, **kwargs)


 batch_normalization_10 (Bat  (None, 22)               88        
 chNormalization)                                                
                                                                 
 dropout_10 (Dropout)        (None, 22)                0         
                                                                 
 dense_18 (Dense)            (None, 292)               6716      
                                                                 
Total params: 6,804
Trainable params: 6,760
Non-trainable params: 44
_________________________________________________________________
None
Model: "decoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_32 (InputLayer)       [(None, 292)]             0         
                                                                 
 reshape_17 (Reshape)        (None, 292, 1)            0         
                                                   

  updates = self.state_updates


29181/29181 - 122s - loss: 0.7341 - accuracy: 0.8077 - categorical_crossentropy: 0.7341 - ignore_accuracy: 0.4132 - val_loss: 0.5819 - val_accuracy: 0.8199 - val_categorical_crossentropy: 0.5819 - val_ignore_accuracy: 0.4104 - 122s/epoch - 4ms/sample
Epoch 2/100
29181/29181 - 117s - loss: 0.5772 - accuracy: 0.8181 - categorical_crossentropy: 0.5772 - ignore_accuracy: 0.4212 - val_loss: 0.5519 - val_accuracy: 0.8235 - val_categorical_crossentropy: 0.5519 - val_ignore_accuracy: 0.4259 - 117s/epoch - 4ms/sample
Epoch 3/100
29181/29181 - 120s - loss: 0.5683 - accuracy: 0.8193 - categorical_crossentropy: 0.5683 - ignore_accuracy: 0.4232 - val_loss: 0.5534 - val_accuracy: 0.8218 - val_categorical_crossentropy: 0.5534 - val_ignore_accuracy: 0.4311 - 120s/epoch - 4ms/sample
Epoch 4/100
29181/29181 - 117s - loss: 0.5607 - accuracy: 0.8205 - categorical_crossentropy: 0.5607 - ignore_accuracy: 0.4246 - val_loss: 0.5434 - val_accuracy: 0.8240 - val_categorical_crossentropy: 0.5434 - val_ignore_acc

In [78]:
# @title Fragment Prediction
smiles = predict('trimmed_data.csv', 'fragment_model')
smiles_df = pd.DataFrame({'SMILES': smiles})

smiles_df.to_csv('%s_smiles.csv' % 'fragment', index=False)


which: train

nchars: 47

max_length: 100

inputdim: 22

data: drive/MyDrive/Data Sets/TurboChemDB5.2.csv

output: fragment_model

labels: -1

kernels: [9, 9, 10]

filters: [10, 10, 11]

embedding_dim: 32

latent_dim: 292

epsilon: 0.1

dropout: 0.2

weights: -1

freeze_vae: False

validation: 0.1

batch_size: 128

epochs: 100

patience: 5

{'nchars': 47, 'max_length': 100, 'inputdim': 22, 'data': 'drive/MyDrive/Data Sets/TurboChemDB5.2.csv', 'output': 'fragment_model', 'labels': None, 'kernels': [9, 9, 10], 'filters': [10, 10, 11], 'embedding_dim': 32, 'latent_dim': 292, 'dropout': 0.2, 'weights': None, 'freeze_vae': False, 'validation': 0.1, 'batch_size': 128, 'epochs': 100, 'patience': 5, 'epsilon_std': 0.1}


  super().__init__(name, **kwargs)


Model: "encoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 batch_normalization_11 (Bat  (None, 22)               88        
 chNormalization)                                                
                                                                 
 dropout_11 (Dropout)        (None, 22)                0         
                                                                 
 dense_20 (Dense)            (None, 292)               6716      
                                                                 
Total params: 6,804
Trainable params: 6,760
Non-trainable params: 44
_________________________________________________________________
None
Model: "decoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_35 (InputLayer)       [(None, 292)]             0         
                                  

  updates=self.state_updates,
